diff options
Diffstat (limited to 'pkg/tcpip/network')
32 files changed, 525 insertions, 16759 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD deleted file mode 100644 index fa8814bac..000000000 --- a/pkg/tcpip/network/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "ip_test", - size = "small", - srcs = [ - "ip_test.go", - "multicast_group_test.go", - ], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD deleted file mode 100644 index d59d678b2..000000000 --- a/pkg/tcpip/network/arp/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "arp", - srcs = [ - "arp.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "arp_test", - size = "small", - srcs = ["arp_test.go"], - deps = [ - ":arp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", - ], -) - -go_test( - name = "stats_test", - size = "small", - srcs = ["stats_test.go"], - library = ":arp", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go new file mode 100644 index 000000000..5cd8535e3 --- /dev/null +++ b/pkg/tcpip/network/arp/arp_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package arp diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go deleted file mode 100644 index 018d6a578..000000000 --- a/pkg/tcpip/network/arp/arp_test.go +++ /dev/null @@ -1,711 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arp_test - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -const ( - nicID = 1 - - stackAddr = tcpip.Address("\x0a\x00\x00\x01") - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - - remoteAddr = tcpip.Address("\x0a\x00\x00\x02") - remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") - - unknownAddr = tcpip.Address("\x0a\x00\x00\x03") - - defaultChannelSize = 1 - defaultMTU = 65536 - - // eventChanSize defines the size of event channels used by the neighbor - // cache's event dispatcher. The size chosen here needs to be sufficient to - // queue all the events received during tests before consumption. - // If eventChanSize is too small, the tests may deadlock. - eventChanSize = 32 -) - -type eventType uint8 - -const ( - entryAdded eventType = iota - entryChanged - entryRemoved -) - -func (t eventType) String() string { - switch t { - case entryAdded: - return "add" - case entryChanged: - return "change" - case entryRemoved: - return "remove" - default: - return fmt.Sprintf("unknown (%d)", t) - } -} - -type eventInfo struct { - eventType eventType - nicID tcpip.NICID - entry stack.NeighborEntry -} - -func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) -} - -// arpDispatcher implements NUDDispatcher to validate the dispatching of -// events upon certain NUD state machine events. -type arpDispatcher struct { - // C is where events are queued - C chan eventInfo -} - -var _ stack.NUDDispatcher = (*arpDispatcher)(nil) - -func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: entry, - } - d.C <- e -} - -func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryChanged, - nicID: nicID, - entry: entry, - } - d.C <- e -} - -func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryRemoved, - nicID: nicID, - entry: entry, - } - d.C <- e -} - -func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { - select { - case got := <-d.C: - if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) - } - case <-ctx.Done(): - return fmt.Errorf("%s for %s", ctx.Err(), want) - } - return nil -} - -func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return d.waitForEvent(ctx, want) -} - -func (d *arpDispatcher) nextEvent() (eventInfo, bool) { - select { - case event := <-d.C: - return event, true - default: - return eventInfo{}, false - } -} - -type testContext struct { - s *stack.Stack - linkEP *channel.Endpoint - nudDisp *arpDispatcher -} - -func newTestContext(t *testing.T) *testContext { - c := stack.DefaultNUDConfigurations() - // Transition from Reachable to Stale almost immediately to test if receiving - // probes refreshes positive reachability. - c.BaseReachableTime = time.Microsecond - - d := arpDispatcher{ - // Create an event channel large enough so the neighbor cache doesn't block - // while dispatching events. Blocking could interfere with the timing of - // NUD transitions. - C: make(chan eventInfo, eventChanSize), - } - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - NUDConfigs: c, - NUDDisp: &d, - }) - - ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - wep := stack.LinkEndpoint(ep) - - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNIC(nicID, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }}) - - return &testContext{ - s: s, - linkEP: ep, - nudDisp: &d, - } -} - -func (c *testContext) cleanup() { - c.linkEP.Close() -} - -func TestMalformedPacket(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - v := make(buffer.View, header.ARPSize) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - - c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } - if got := c.s.Stats().ARP.MalformedPacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.MalformedPacketsReceived.Value() = %d, want = 1", got) - } -} - -func TestDisabledEndpoint(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) - if err != nil { - t.Fatalf("GetNetworkEndpoint(%d, header.ARPProtocolNumber) failed: %s", nicID, err) - } - ep.Disable() - - v := make(buffer.View, header.ARPSize) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - - c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } - if got := c.s.Stats().ARP.DisabledPacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.DisabledPacketsReceived.Value() = %d, want = 1", got) - } -} - -func TestDirectReply(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - const senderMAC = "\x01\x02\x03\x04\x05\x06" - const senderIPv4 = "\x0a\x00\x00\x02" - - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPReply) - - copy(h.HardwareAddressSender(), senderMAC) - copy(h.ProtocolAddressSender(), senderIPv4) - copy(h.HardwareAddressTarget(), stackLinkAddr) - copy(h.ProtocolAddressTarget(), stackAddr) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - - c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } - if got := c.s.Stats().ARP.RepliesReceived.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) - } -} - -func TestDirectRequest(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - tests := []struct { - name string - senderAddr tcpip.Address - senderLinkAddr tcpip.LinkAddress - targetAddr tcpip.Address - isValid bool - }{ - { - name: "Loopback", - senderAddr: stackAddr, - senderLinkAddr: stackLinkAddr, - targetAddr: stackAddr, - isValid: true, - }, - { - name: "Remote", - senderAddr: remoteAddr, - senderLinkAddr: remoteLinkAddr, - targetAddr: stackAddr, - isValid: true, - }, - { - name: "RemoteInvalidTarget", - senderAddr: remoteAddr, - senderLinkAddr: remoteLinkAddr, - targetAddr: unknownAddr, - isValid: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - packetsRecv := c.s.Stats().ARP.PacketsReceived.Value() - requestsRecv := c.s.Stats().ARP.RequestsReceived.Value() - requestsRecvUnknownAddr := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() - outgoingReplies := c.s.Stats().ARP.OutgoingRepliesSent.Value() - - // Inject an incoming ARP request. - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), test.senderLinkAddr) - copy(h.ProtocolAddressSender(), test.senderAddr) - copy(h.ProtocolAddressTarget(), test.targetAddr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - })) - - if got, want := c.s.Stats().ARP.PacketsReceived.Value(), packetsRecv+1; got != want { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want) - } - if got, want := c.s.Stats().ARP.RequestsReceived.Value(), requestsRecv+1; got != want { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want) - } - - if !test.isValid { - // No packets should be sent after receiving an invalid ARP request. - // There is no need to perform a blocking read here, since packets are - // sent in the same function that handles ARP requests. - if pkt, ok := c.linkEP.Read(); ok { - t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto) - } - if got, want := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(), requestsRecvUnknownAddr+1; got != want { - t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() = %d, want = %d", got, want) - } - if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies; got != want { - t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want) - } - - return - } - - if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies+1; got != want { - t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want) - } - - // Verify an ARP response was sent. - pi, ok := c.linkEP.Read() - if !ok { - t.Fatal("expected ARP response to be sent, got none") - } - - if pi.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) - } - rep := header.ARP(pi.Pkt.NetworkHeader().View()) - if !rep.IsValid() { - t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { - t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { - t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { - t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { - t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want) - } - - // Verify the sender was saved in the neighbor cache. - wantEvent := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: stack.NeighborEntry{ - Addr: test.senderAddr, - LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), - State: stack.Stale, - }, - } - if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { - t.Fatal(err) - } - - neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) - if err != nil { - t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) - } - - neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) - for _, n := range neighbors { - if existing, ok := neighborByAddr[n.Addr]; ok { - if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff) - } - t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr) - } - neighborByAddr[n.Addr] = n - } - - neigh, ok := neighborByAddr[test.senderAddr] - if !ok { - t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr) - } - if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { - t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) - } - if got, want := neigh.State, stack.Stale; got != want { - t.Errorf("got neighbor State = %s, want = %s", got, want) - } - - // No more events should be dispatched - for { - event, ok := c.nudDisp.nextEvent() - if !ok { - break - } - t.Errorf("unexpected %s", event) - } - }) - } -} - -var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil) - -type testLinkEndpoint struct { - stack.LinkEndpoint - - writeErr tcpip.Error -} - -func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - if t.writeErr != nil { - return t.writeErr - } - - return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) -} - -func TestLinkAddressRequest(t *testing.T) { - const nicID = 1 - - testAddr := tcpip.Address([]byte{1, 2, 3, 4}) - - tests := []struct { - name string - nicAddr tcpip.Address - localAddr tcpip.Address - remoteLinkAddr tcpip.LinkAddress - linkErr tcpip.Error - expectedErr tcpip.Error - expectedLocalAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress - expectedRequestsSent uint64 - expectedRequestBadLocalAddressErrors uint64 - expectedRequestInterfaceHasNoLocalAddressErrors uint64 - expectedRequestDroppedErrors uint64 - }{ - { - name: "Unicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Unicast with unspecified source", - nicAddr: stackAddr, - localAddr: "", - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast with unspecified source", - nicAddr: stackAddr, - localAddr: "", - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Unicast with unassigned address", - nicAddr: stackAddr, - localAddr: testAddr, - remoteLinkAddr: remoteLinkAddr, - expectedErr: &tcpip.ErrBadLocalAddress{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast with unassigned address", - nicAddr: stackAddr, - localAddr: testAddr, - remoteLinkAddr: "", - expectedErr: &tcpip.ErrBadLocalAddress{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 0, - }, - { - name: "Unicast with no local address available", - nicAddr: "", - localAddr: "", - remoteLinkAddr: remoteLinkAddr, - expectedErr: &tcpip.ErrNetworkUnreachable{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 1, - expectedRequestDroppedErrors: 0, - }, - { - name: "Multicast with no local address available", - nicAddr: "", - localAddr: "", - remoteLinkAddr: "", - expectedErr: &tcpip.ErrNetworkUnreachable{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 1, - expectedRequestDroppedErrors: 0, - }, - { - name: "Link error", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - linkErr: &tcpip.ErrInvalidEndpointState{}, - expectedErr: &tcpip.ErrInvalidEndpointState{}, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestInterfaceHasNoLocalAddressErrors: 0, - expectedRequestDroppedErrors: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - }) - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err) - } - linkRes, ok := ep.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) - } - - if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) - } - } - - { - err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) - } - } - - if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { - t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent) - } - if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != test.expectedRequestInterfaceHasNoLocalAddressErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestInterfaceHasNoLocalAddressErrors) - } - if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors) - } - if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors) - } - - if test.expectedErr != nil { - return - } - - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } - - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) - } - - rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if got := rep.Op(); got != header.ARPRequest { - t.Errorf("got Op = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr) - } - if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { - t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) - } - if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr) - } - }) - } -} - -func TestDADARPRequestPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - }, - }), ipv4.NewProtocol}, - }) - e := channel.New(1, defaultMTU, stackLinkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - if res, err := s.CheckDuplicateAddress(nicID, header.IPv4ProtocolNumber, remoteAddr, func(stack.DADResult) {}); err != nil { - t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, header.IPv4ProtocolNumber, remoteAddr, err) - } else if res != stack.DADStarting { - t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting) - } - - pkt, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatal("expected to send an ARP request") - } - - if pkt.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) - } - - req := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if !req.IsValid() { - t.Errorf("got req.IsValid() = false, want = true") - } - if got := req.Op(); got != header.ARPRequest { - t.Errorf("got req.Op() = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(req.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("got req.HardwareAddressSender() = %s, want = %s", got, stackLinkAddr) - } - if got := tcpip.Address(req.ProtocolAddressSender()); got != header.IPv4Any { - t.Errorf("got req.ProtocolAddressSender() = %s, want = %s", got, header.IPv4Any) - } - if got, want := tcpip.LinkAddress(req.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { - t.Errorf("got req.HardwareAddressTarget() = %s, want = %s", got, want) - } - if got := tcpip.Address(req.ProtocolAddressTarget()); got != remoteAddr { - t.Errorf("got req.ProtocolAddressTarget() = %s, want = %s", got, remoteAddr) - } -} diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go deleted file mode 100644 index e867b3c3f..000000000 --- a/pkg/tcpip/network/arp/stats_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arp - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.NetworkInterface - nicID tcpip.NICID -} - -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - // At this point, the Stack's stats and the NetworkEndpoint's stats are - // expected to be bound by a MultiCounterStat. - refStack := s.Stats() - refEP := ep.stats.localStats - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.arp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ARP).Elem(), reflect.ValueOf(&refStack.ARP).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD deleted file mode 100644 index 872165866..000000000 --- a/pkg/tcpip/network/hash/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hash", - srcs = ["hash.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/rand", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go new file mode 100644 index 000000000..9467fe298 --- /dev/null +++ b/pkg/tcpip/network/hash/hash_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hash diff --git a/pkg/tcpip/network/internal/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD deleted file mode 100644 index 274f09092..000000000 --- a/pkg/tcpip/network/internal/fragmentation/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "reassembler_list", - out = "reassembler_list.go", - package = "fragmentation", - prefix = "reassembler", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*reassembler", - "Linker": "*reassembler", - }, -) - -go_library( - name = "fragmentation", - srcs = [ - "fragmentation.go", - "reassembler.go", - "reassembler_list.go", - ], - visibility = [ - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/log", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "fragmentation_test", - size = "small", - srcs = [ - "fragmentation_test.go", - "reassembler_test.go", - ], - library = ":fragmentation", - deps = [ - "//pkg/tcpip/buffer", - "//pkg/tcpip/faketime", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go new file mode 100644 index 000000000..3f82c184a --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go @@ -0,0 +1,64 @@ +// automatically generated by stateify. + +package fragmentation + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (l *reassemblerList) StateTypeName() string { + return "pkg/tcpip/network/internal/fragmentation.reassemblerList" +} + +func (l *reassemblerList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *reassemblerList) beforeSave() {} + +func (l *reassemblerList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *reassemblerList) afterLoad() {} + +func (l *reassemblerList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *reassemblerEntry) StateTypeName() string { + return "pkg/tcpip/network/internal/fragmentation.reassemblerEntry" +} + +func (e *reassemblerEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *reassemblerEntry) beforeSave() {} + +func (e *reassemblerEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *reassemblerEntry) afterLoad() {} + +func (e *reassemblerEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*reassemblerList)(nil)) + state.Register((*reassemblerEntry)(nil)) +} diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go deleted file mode 100644 index 7daf64b4a..000000000 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ /dev/null @@ -1,636 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "errors" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// reassembleTimeout is dummy timeout used for testing, where the clock never -// advances. -const reassembleTimeout = 1 - -// vv is a helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) buffer.VectorisedView { - views := make([]buffer.View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return buffer.NewVectorisedView(size, views) -} - -func pkt(size int, pieces ...string) *stack.PacketBuffer { - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv(size, pieces...), - }) -} - -type processInput struct { - id FragmentID - first uint16 - last uint16 - more bool - proto uint8 - pkt *stack.PacketBuffer -} - -type processOutput struct { - vv buffer.VectorisedView - proto uint8 - done bool -} - -var processTestCases = []struct { - comment string - in []processInput - out []processOutput -}{ - { - comment: "One ID", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, - { - comment: "Next Header protocol mismatch", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), proto: 6, done: true}, - }, - }, - { - comment: "Two IDs", - in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, - {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, - {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "ab", "cd"), done: true}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, -} - -func TestFragmentationProcess(t *testing.T) { - for _, c := range processTestCases { - t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) - firstFragmentProto := c.in[0].proto - for i, in := range c.in { - resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) - if err != nil { - t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", - in.id, in.first, in.last, in.more, in.proto, in.pkt, err) - } - if done != c.out[i].done { - t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", - in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) - } - if c.out[i].done { - if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data().AsRange().ToOwnedView()); diff != "" { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", - in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) - } - if firstFragmentProto != proto { - t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", - in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) - } - if _, ok := f.reassemblers[in.id]; ok { - t.Errorf("Process(%d) did not remove buffer from reassemblers", i) - } - for n := f.rList.Front(); n != nil; n = n.Next() { - if n.id == in.id { - t.Errorf("Process(%d) did not remove buffer from rList", i) - } - } - } - } - }) - } -} - -func TestReassemblingTimeout(t *testing.T) { - const ( - reassemblyTimeout = time.Millisecond - protocol = 0xff - ) - - type fragment struct { - first uint16 - last uint16 - more bool - data string - } - - type event struct { - // name is a nickname of this event. - name string - - // clockAdvance is a duration to advance the clock. The clock advances - // before a fragment specified in the fragment field is processed. - clockAdvance time.Duration - - // fragment is a fragment to process. This can be nil if there is no - // fragment to process. - fragment *fragment - - // expectDone is true if the fragmentation instance should report the - // reassembly is done after the fragment is processd. - expectDone bool - - // memSizeAfterEvent is the expected memory size of the fragmentation - // instance after the event. - memSizeAfterEvent int - } - - memSizeOfFrags := func(frags ...*fragment) int { - var size int - for _, frag := range frags { - size += pkt(len(frag.data), frag.data).MemSize() - } - return size - } - - half1 := &fragment{first: 0, last: 0, more: true, data: "0"} - half2 := &fragment{first: 1, last: 1, more: false, data: "1"} - - tests := []struct { - name string - events []event - }{ - { - name: "half1 and half2 are reassembled successfully", - events: []event{ - { - name: "half1", - fragment: half1, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half2", - fragment: half2, - expectDone: true, - memSizeAfterEvent: 0, - }, - }, - }, - { - name: "half1 timeout, half2 timeout", - events: []event{ - { - name: "half1", - fragment: half1, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half1 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - memSizeAfterEvent: memSizeOfFrags(half1), - }, - { - name: "half1 reassembly timeout", - clockAdvance: 1, - memSizeAfterEvent: 0, - }, - { - name: "half2", - fragment: half2, - expectDone: false, - memSizeAfterEvent: memSizeOfFrags(half2), - }, - { - name: "half2 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - memSizeAfterEvent: memSizeOfFrags(half2), - }, - { - name: "half2 reassembly timeout", - clockAdvance: 1, - memSizeAfterEvent: 0, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) - for _, event := range test.events { - clock.Advance(event.clockAdvance) - if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) - if err != nil { - t.Fatalf("%s: f.Process failed: %s", event.name, err) - } - if done != event.expectDone { - t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) - } - } - if got, want := f.memSize, event.memSizeAfterEvent; got != want { - t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want) - } - } - }) - } -} - -func TestMemoryLimits(t *testing.T) { - lowLimit := pkt(1, "0").MemSize() - highLimit := 3 * lowLimit // Allow at most 3 such packets. - f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) - // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) - // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) - // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) - - // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be - // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) - - if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { - t.Errorf("Memory limits are not respected: id=0 has not been evicted.") - } - if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok { - t.Errorf("Memory limits are not respected: id=1 has not been evicted.") - } - if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok { - t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") - } -} - -func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - memSize := pkt(1, "0").MemSize() - f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) - // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) - // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) - - if got, want := f.memSize, memSize; got != want { - t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) - } -} - -func TestErrors(t *testing.T) { - tests := []struct { - name string - blockSize uint16 - first uint16 - last uint16 - more bool - data string - err error - }{ - { - name: "exact block size without more", - blockSize: 2, - first: 2, - last: 3, - more: false, - data: "01", - }, - { - name: "exact block size with more", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "01", - }, - { - name: "exact block size with more and extra data", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "012", - err: ErrInvalidArgs, - }, - { - name: "exact block size with more and too little data", - blockSize: 2, - first: 2, - last: 3, - more: true, - data: "0", - err: ErrInvalidArgs, - }, - { - name: "not exact block size with more", - blockSize: 2, - first: 2, - last: 2, - more: true, - data: "0", - err: ErrInvalidArgs, - }, - { - name: "not exact block size without more", - blockSize: 2, - first: 2, - last: 2, - more: false, - data: "0", - }, - { - name: "first not a multiple of block size", - blockSize: 2, - first: 3, - last: 4, - more: true, - data: "01", - err: ErrInvalidArgs, - }, - { - name: "first more than last", - blockSize: 2, - first: 4, - last: 3, - more: true, - data: "01", - err: ErrInvalidArgs, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) - if !errors.Is(err, test.err) { - t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) - } - if done { - t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data) - } - }) - } -} - -type fragmentInfo struct { - remaining int - copied int - offset int - more bool -} - -func TestPacketFragmenter(t *testing.T) { - const ( - reserve = 60 - proto = 0 - ) - - tests := []struct { - name string - fragmentPayloadLen uint32 - transportHeaderLen int - payloadSize int - wantFragments []fragmentInfo - }{ - { - name: "Packet exactly fits in MTU", - fragmentPayloadLen: 1280, - transportHeaderLen: 0, - payloadSize: 1280, - wantFragments: []fragmentInfo{ - {remaining: 0, copied: 1280, offset: 0, more: false}, - }, - }, - { - name: "Packet exactly does not fit in MTU", - fragmentPayloadLen: 1000, - transportHeaderLen: 0, - payloadSize: 1001, - wantFragments: []fragmentInfo{ - {remaining: 1, copied: 1000, offset: 0, more: true}, - {remaining: 0, copied: 1, offset: 1000, more: false}, - }, - }, - { - name: "Packet has a transport header", - fragmentPayloadLen: 560, - transportHeaderLen: 40, - payloadSize: 560, - wantFragments: []fragmentInfo{ - {remaining: 1, copied: 560, offset: 0, more: true}, - {remaining: 0, copied: 40, offset: 560, more: false}, - }, - }, - { - name: "Packet has a huge transport header", - fragmentPayloadLen: 500, - transportHeaderLen: 1300, - payloadSize: 500, - wantFragments: []fragmentInfo{ - {remaining: 3, copied: 500, offset: 0, more: true}, - {remaining: 2, copied: 500, offset: 500, more: true}, - {remaining: 1, copied: 500, offset: 1000, more: true}, - {remaining: 0, copied: 300, offset: 1500, more: false}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) - originalPayload := stack.PayloadSince(pkt.TransportHeader()) - var reassembledPayload buffer.VectorisedView - pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) - for i := 0; ; i++ { - fragPkt, offset, copied, more := pf.BuildNextFragment() - wantFragment := test.wantFragments[i] - if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { - t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) - } - if copied != wantFragment.copied { - t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) - } - if offset != wantFragment.offset { - t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) - } - if more != wantFragment.more { - t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) - } - if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { - t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) - } - if got := fragPkt.AvailableHeaderBytes(); got != reserve { - t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) - } - if got := fragPkt.TransportHeader().View().Size(); got != 0 { - t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) - } - reassembledPayload.AppendViews(fragPkt.Data().Views()) - if !more { - if i != len(test.wantFragments)-1 { - t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) - } - break - } - } - if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload); diff != "" { - t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type testTimeoutHandler struct { - pkt *stack.PacketBuffer -} - -func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { - h.pkt = pkt -} - -func TestTimeoutHandler(t *testing.T) { - const ( - proto = 99 - ) - - pk1 := pkt(1, "1") - pk2 := pkt(1, "2") - - type processParam struct { - first uint16 - last uint16 - more bool - pkt *stack.PacketBuffer - } - - tests := []struct { - name string - params []processParam - wantError bool - wantPkt *stack.PacketBuffer - }{ - { - name: "onTimeout runs", - params: []processParam{ - { - first: 0, - last: 0, - more: true, - pkt: pk1, - }, - }, - wantError: false, - wantPkt: pk1, - }, - { - name: "no first fragment", - params: []processParam{ - { - first: 1, - last: 1, - more: true, - pkt: pk1, - }, - }, - wantError: false, - wantPkt: nil, - }, - { - name: "second pkt is ignored", - params: []processParam{ - { - first: 0, - last: 0, - more: true, - pkt: pk1, - }, - { - first: 0, - last: 0, - more: true, - pkt: pk2, - }, - }, - wantError: false, - wantPkt: pk1, - }, - { - name: "invalid args - first is greater than last", - params: []processParam{ - { - first: 1, - last: 0, - more: true, - pkt: pk1, - }, - }, - wantError: true, - wantPkt: nil, - }, - } - - id := FragmentID{ID: 0} - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - handler := &testTimeoutHandler{pkt: nil} - - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) - - for _, p := range test.params { - if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError { - t.Errorf("f.Process error = %s", err) - } - } - if !test.wantError { - r, ok := f.reassemblers[id] - if !ok { - t.Fatal("Reassembler not found") - } - f.release(r, true) - } - switch { - case handler.pkt != nil && test.wantPkt == nil: - t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data().AsRange().ToOwnedView()) - case handler.pkt == nil && test.wantPkt != nil: - t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data().AsRange().ToOwnedView()) - case handler.pkt != nil && test.wantPkt != nil: - if diff := cmp.Diff(test.wantPkt.Data().AsRange().ToOwnedView(), handler.pkt.Data().AsRange().ToOwnedView()); diff != "" { - t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) - } - } - }) - } -} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_list.go b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go new file mode 100644 index 000000000..673bb11b0 --- /dev/null +++ b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go @@ -0,0 +1,221 @@ +package fragmentation + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type reassemblerElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type reassemblerList struct { + head *reassembler + tail *reassembler +} + +// Reset resets list l to the empty state. +func (l *reassemblerList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *reassemblerList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *reassemblerList) Front() *reassembler { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *reassemblerList) Back() *reassembler { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *reassemblerList) Len() (count int) { + for e := l.Front(); e != nil; e = (reassemblerElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *reassemblerList) PushFront(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *reassemblerList) PushBack(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *reassemblerList) PushBackList(m *reassemblerList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head) + reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *reassemblerList) InsertAfter(b, e *reassembler) { + bLinker := reassemblerElementMapper{}.linkerFor(b) + eLinker := reassemblerElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + reassemblerElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *reassemblerList) InsertBefore(a, e *reassembler) { + aLinker := reassemblerElementMapper{}.linkerFor(a) + eLinker := reassemblerElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + reassemblerElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *reassemblerList) Remove(e *reassembler) { + linker := reassemblerElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + reassemblerElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + reassemblerElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type reassemblerEntry struct { + next *reassembler + prev *reassembler +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *reassemblerEntry) Next() *reassembler { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *reassemblerEntry) Prev() *reassembler { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *reassemblerEntry) SetNext(elem *reassembler) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *reassemblerEntry) SetPrev(elem *reassembler) { + e.prev = elem +} diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go deleted file mode 100644 index cfd9f00ef..000000000 --- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "bytes" - "math" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type processParams struct { - first uint16 - last uint16 - more bool - pkt *stack.PacketBuffer - wantDone bool - wantError error -} - -func TestReassemblerProcess(t *testing.T) { - const proto = 99 - - v := func(size int) buffer.View { - payload := buffer.NewView(size) - for i := 1; i < size; i++ { - payload[i] = uint8(i) * 3 - } - return payload - } - - pkt := func(sizes ...int) *stack.PacketBuffer { - var vv buffer.VectorisedView - for _, size := range sizes { - vv.AppendView(v(size)) - } - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - } - - var tests = []struct { - name string - params []processParams - want []hole - wantPkt *stack.PacketBuffer - }{ - { - name: "No fragments", - params: nil, - want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, - }, - { - name: "One fragment at beginning", - params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)}, - {first: 2, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "One fragment in the middle", - params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)}, - {first: 0, last: 0, filled: false, final: false}, - {first: 3, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "One fragment at the end", - params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, - want: []hole{ - {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)}, - {first: 0, last: 0, filled: false}, - }, - }, - { - name: "One fragment completing a packet", - params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, - want: []hole{ - {first: 0, last: 1, filled: true, final: true}, - }, - wantPkt: pkt(2), - }, - { - name: "Two fragments completing a packet", - params: []processParams{ - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 1, filled: true, final: false}, - {first: 2, last: 3, filled: true, final: true}, - }, - wantPkt: pkt(2, 2), - }, - { - name: "Two fragments completing a packet with a duplicate", - params: []processParams{ - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 1, filled: true, final: false}, - {first: 2, last: 3, filled: true, final: true}, - }, - wantPkt: pkt(2, 2), - }, - { - name: "Two fragments completing a packet with a partial duplicate", - params: []processParams{ - {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, - {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, - {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, - }, - want: []hole{ - {first: 0, last: 3, filled: true, final: false}, - {first: 4, last: 5, filled: true, final: true}, - }, - wantPkt: pkt(4, 2), - }, - { - name: "Two overlapping fragments", - params: []processParams{ - {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, - {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, - }, - want: []hole{ - {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)}, - {first: 11, last: math.MaxUint16, filled: false, final: true}, - }, - }, - { - name: "Two final fragments with different ends", - params: []processParams{ - {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, - {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, - }, - want: []hole{ - {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)}, - {first: 0, last: 9, filled: false, final: false}, - }, - }, - { - name: "Two final fragments - duplicate", - params: []processParams{ - {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, - {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, - }, - want: []hole{ - {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, - {first: 0, last: 4, filled: false, final: false}, - }, - }, - { - name: "Two final fragments - duplicate, with different ends", - params: []processParams{ - {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, - {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, - }, - want: []hole{ - {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, - {first: 0, last: 4, filled: false, final: false}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - var resPkt *stack.PacketBuffer - var isDone bool - for _, param := range test.params { - pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) - if done != param.wantDone || err != param.wantError { - t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) - } - if done { - resPkt = pkt - isDone = true - } - } - - ignorePkt := func(a, b *stack.PacketBuffer) bool { return true } - cmpPktData := func(a, b *stack.PacketBuffer) bool { - if a == nil || b == nil { - return a == b - } - return bytes.Equal(a.Data().AsRange().ToOwnedView(), b.Data().AsRange().ToOwnedView()) - } - - if isDone { - if diff := cmp.Diff( - test.want, r.holes, - cmp.AllowUnexported(hole{}), - // Do not compare pkt in hole. Data will be altered. - cmp.Comparer(ignorePkt), - ); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) - } - if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" { - t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff) - } - } else { - if diff := cmp.Diff( - test.want, r.holes, - cmp.AllowUnexported(hole{}), - cmp.Comparer(cmpPktData), - ); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) - } - } - }) - } -} diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD deleted file mode 100644 index d21b4c7ef..000000000 --- a/pkg/tcpip/network/internal/ip/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ip", - srcs = [ - "duplicate_address_detection.go", - "generic_multicast_protocol.go", - "stats.go", - ], - visibility = [ - "//pkg/tcpip/network/arp:__pkg__", - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ip_x_test", - size = "small", - srcs = [ - "duplicate_address_detection_test.go", - "generic_multicast_protocol_test.go", - ], - deps = [ - ":ip", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/faketime", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go deleted file mode 100644 index e00aa4678..000000000 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type mockDADProtocol struct { - t *testing.T - - mu struct { - sync.Mutex - - dad ip.DAD - sendCount map[tcpip.Address]int - } -} - -func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip.DADOptions) { - m.mu.Lock() - defer m.mu.Unlock() - - m.t = t - opts.Protocol = m - m.mu.dad.Init(&m.mu, c, opts) - m.initLocked() -} - -func (m *mockDADProtocol) initLocked() { - m.mu.sendCount = make(map[tcpip.Address]int) -} - -func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address) tcpip.Error { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.sendCount[addr]++ - return nil -} - -func (m *mockDADProtocol) check(addrs []tcpip.Address) string { - m.mu.Lock() - defer m.mu.Unlock() - - sendCount := make(map[tcpip.Address]int) - for _, a := range addrs { - sendCount[a]++ - } - - diff := cmp.Diff(sendCount, m.mu.sendCount) - m.initLocked() - return diff -} - -func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.dad.CheckDuplicateAddressLocked(addr, h) -} - -func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.dad.StopLocked(addr, reason) -} - -func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.dad.SetConfigsLocked(c) -} - -const ( - addr1 = tcpip.Address("\x01") - addr2 = tcpip.Address("\x02") - addr3 = tcpip.Address("\x03") - addr4 = tcpip.Address("\x04") -) - -type dadResult struct { - Addr tcpip.Address - R stack.DADResult -} - -func handler(ch chan<- dadResult, a tcpip.Address) func(stack.DADResult) { - return func(r stack.DADResult) { - ch <- dadResult{Addr: a, R: r} - } -} - -func TestDADCheckDuplicateAddress(t *testing.T) { - var dad mockDADProtocol - clock := faketime.NewManualClock() - dad.init(t, stack.DADConfigurations{}, ip.DADOptions{ - Clock: clock, - }) - - ch := make(chan dadResult, 2) - - // DAD should initially be disabled. - if res := dad.checkDuplicateAddress(addr1, handler(nil, "")); res != stack.DADDisabled { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) - } - // Wait for any initially fired timers to complete. - clock.Advance(0) - if diff := dad.check(nil); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - // Enable and request DAD. - dadConfigs1 := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - dad.setConfigs(dadConfigs1) - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - clock.Advance(0) - if diff := dad.check([]tcpip.Address{addr1}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - // The second request for DAD on the same address should use the original - // request since it has not completed yet. - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) - } - clock.Advance(0) - if diff := dad.check(nil); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - dadConfigs2 := stack.DADConfigurations{ - DupAddrDetectTransmits: 2, - RetransmitTimer: time.Second, - } - dad.setConfigs(dadConfigs2) - // A new address should start a new DAD process. - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.Advance(0) - if diff := dad.check([]tcpip.Address{addr2}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - // Make sure DAD for addr1 only resolves after the expected timeout. - const delta = time.Nanosecond - dadConfig1Duration := time.Duration(dadConfigs1.DupAddrDetectTransmits) * dadConfigs1.RetransmitTimer - clock.Advance(dadConfig1Duration - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig1Duration, r) - default: - } - clock.Advance(delta) - for i := 0; i < 2; i++ { - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff) - } - } - - // Make sure DAD for addr2 only resolves after the expected timeout. - dadConfig2Duration := time.Duration(dadConfigs2.DupAddrDetectTransmits) * dadConfigs2.RetransmitTimer - clock.Advance(dadConfig2Duration - dadConfig1Duration - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig2Duration, r) - default: - } - clock.Advance(delta) - if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should be able to restart DAD for addr2 after it resolved. - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.Advance(0) - if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - clock.Advance(dadConfig2Duration) - if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should not have anymore results. - select { - case r := <-ch: - t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } -} - -func TestDADStop(t *testing.T) { - var dad mockDADProtocol - clock := faketime.NewManualClock() - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - dad.init(t, dadConfigs, ip.DADOptions{ - Clock: clock, - }) - - ch := make(chan dadResult, 1) - - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) - } - clock.Advance(0) - if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - - dad.stop(addr1, &stack.DADAborted{}) - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADAborted{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - dad.stop(addr2, &stack.DADDupAddrDetected{}) - if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADDupAddrDetected{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer - clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr3, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should be able to restart DAD for an address we stopped DAD on. - if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { - t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) - } - clock.Advance(0) - if diff := dad.check([]tcpip.Address{addr1}); diff != "" { - t.Errorf("dad check mismatch (-want +got):\n%s", diff) - } - clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { - t.Errorf("dad result mismatch (-want +got):\n%s", diff) - } - - // Should not have anymore updates. - select { - case r := <-ch: - t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } -} diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go deleted file mode 100644 index 381460c82..000000000 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ /dev/null @@ -1,805 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" -) - -const maxUnsolicitedReportDelay = time.Second - -var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) - -type mockMulticastGroupProtocolProtectedFields struct { - sync.RWMutex - - genericMulticastGroup ip.GenericMulticastProtocolState - sendReportGroupAddrCount map[tcpip.Address]int - sendLeaveGroupAddrCount map[tcpip.Address]int - makeQueuePackets bool - disabled bool -} - -type mockMulticastGroupProtocol struct { - t *testing.T - - mu mockMulticastGroupProtocolProtectedFields -} - -func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) { - m.mu.Lock() - defer m.mu.Unlock() - m.initLocked() - opts.Protocol = m - m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) -} - -func (m *mockMulticastGroupProtocol) initLocked() { - m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int) - m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) -} - -func (m *mockMulticastGroupProtocol) setEnabled(v bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.disabled = !v -} - -func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.makeQueuePackets = v -} - -func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.JoinGroupLocked(addr) -} - -func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool { - m.mu.Lock() - defer m.mu.Unlock() - return m.mu.genericMulticastGroup.LeaveGroupLocked(addr) -} - -func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.HandleReportLocked(addr) -} - -func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime) -} - -func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr) -} - -func (m *mockMulticastGroupProtocol) makeAllNonMember() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.MakeAllNonMemberLocked() -} - -func (m *mockMulticastGroupProtocol) initializeGroups() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.InitializeGroupsLocked() -} - -func (m *mockMulticastGroupProtocol) sendQueuedReports() { - m.mu.Lock() - defer m.mu.Unlock() - m.mu.genericMulticastGroup.SendQueuedReportsLocked() -} - -// Enabled implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be read locked. -func (m *mockMulticastGroupProtocol) Enabled() bool { - if m.mu.TryLock() { - m.mu.Unlock() - m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") - } - - return !m.mu.disabled -} - -// SendReport implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { - if m.mu.TryLock() { - m.mu.Unlock() - m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) - } - if m.mu.TryRLock() { - m.mu.RUnlock() - m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) - } - - m.mu.sendReportGroupAddrCount[groupAddress]++ - return !m.mu.makeQueuePackets, nil -} - -// SendLeave implements ip.MulticastGroupProtocol. -// -// Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { - if m.mu.TryLock() { - m.mu.Unlock() - m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) - } - if m.mu.TryRLock() { - m.mu.RUnlock() - m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) - } - - m.mu.sendLeaveGroupAddrCount[groupAddress]++ - return nil -} - -func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { - m.mu.Lock() - defer m.mu.Unlock() - - sendReportGroupAddrCount := make(map[tcpip.Address]int) - for _, a := range sendReportGroupAddresses { - sendReportGroupAddrCount[a] = 1 - } - - sendLeaveGroupAddrCount := make(map[tcpip.Address]int) - for _, a := range sendLeaveGroupAddresses { - sendLeaveGroupAddrCount[a] = 1 - } - - diff := cmp.Diff( - &mockMulticastGroupProtocol{ - mu: mockMulticastGroupProtocolProtectedFields{ - sendReportGroupAddrCount: sendReportGroupAddrCount, - sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, - }, - }, - m, - cmp.AllowUnexported(mockMulticastGroupProtocol{}), - cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), - // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t - cmp.FilterPath( - func(p cmp.Path) bool { - switch p.Last().String() { - case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": - return true - } - return false - }, - cmp.Ignore(), - ), - ) - m.initLocked() - return diff -} - -func TestJoinGroup(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - shouldSendReports bool - }{ - { - name: "Normal group", - addr: addr1, - shouldSendReports: true, - }, - { - name: "All-nodes group", - addr: addr2, - shouldSendReports: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(0)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, - }) - - // Joining a group should send a report immediately and another after - // a random interval between 0 and the maximum unsolicited report delay. - mgp.joinGroup(test.addr) - if test.shouldSendReports { - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestLeaveGroup(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - shouldSendMessages bool - }{ - { - name: "Normal group", - addr: addr1, - shouldSendMessages: true, - }, - { - name: "All-nodes group", - addr: addr2, - shouldSendMessages: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(1)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, - }) - - mgp.joinGroup(test.addr) - if test.shouldSendMessages { - if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Leaving a group should send a leave report immediately and cancel any - // delayed reports. - { - - if !mgp.leaveGroup(test.addr) { - t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) - } - } - if test.shouldSendMessages { - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - // - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestHandleReport(t *testing.T) { - tests := []struct { - name string - reportAddr tcpip.Address - expectReportsFor []tcpip.Address - }{ - { - name: "Unpecified empty", - reportAddr: "", - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Unpecified any", - reportAddr: "\x00", - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified", - reportAddr: addr1, - expectReportsFor: []tcpip.Address{addr2}, - }, - { - name: "Specified all-nodes", - reportAddr: addr3, - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified other", - reportAddr: addr4, - expectReportsFor: []tcpip.Address{addr1, addr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(2)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a report for a group we have a timer scheduled for should - // cancel our delayed report timer for the group. - mgp.handleReport(test.reportAddr) - if len(test.expectReportsFor) != 0 { - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestHandleQuery(t *testing.T) { - tests := []struct { - name string - queryAddr tcpip.Address - maxDelay time.Duration - expectQueriedReportsFor []tcpip.Address - expectDelayedReportsFor []tcpip.Address - }{ - { - name: "Unpecified empty", - queryAddr: "", - maxDelay: 0, - expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, - expectDelayedReportsFor: nil, - }, - { - name: "Unpecified any", - queryAddr: "\x00", - maxDelay: 1, - expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, - expectDelayedReportsFor: nil, - }, - { - name: "Specified", - queryAddr: addr1, - maxDelay: 2, - expectQueriedReportsFor: []tcpip.Address{addr1}, - expectDelayedReportsFor: []tcpip.Address{addr2}, - }, - { - name: "Specified all-nodes", - queryAddr: addr3, - maxDelay: 3, - expectQueriedReportsFor: nil, - expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, - }, - { - name: "Specified other", - queryAddr: addr4, - maxDelay: 4, - expectQueriedReportsFor: nil, - expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a query should make us reschedule our delayed report timer - // to some time within the new max response delay. - mgp.handleQuery(test.queryAddr, test.maxDelay) - clock.Advance(test.maxDelay) - if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The groups that were not affected by the query should still send a - // report after the max unsolicited report delay. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestJoinCount(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(4)), - Clock: clock, - MaxUnsolicitedReportDelay: time.Second, - }) - - // Set the join count to 2 for a group. - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - // Only the first join should trigger a report to be sent. - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Group should still be considered joined after leaving once. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) - } - if !mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - // A leave report should only be sent once the join count reaches 0. - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Leaving once more should actually remove us from the group. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - - // Group should no longer be joined so we should not have anything to - // leave. - if mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - // - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -func TestMakeAllNonMemberAndInitialize(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, - }) - - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr3) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should send the leave reports for each but still consider them locally - // joined. - mgp.makeAllNonMember() - if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - for _, group := range []tcpip.Address{addr1, addr2, addr3} { - if !mgp.isLocallyJoined(group) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) - } - } - - // Should send the initial set of unsolcited reports. - mgp.initializeGroups() - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -// TestGroupStateNonMember tests that groups do not send packets when in the -// non-member state, but are still considered locally joined. -func TestGroupStateNonMember(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() - - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - mgp.setEnabled(false) - - // Joining groups should not send any reports. - mgp.joinGroup(addr1) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.joinGroup(addr2) - if !mgp.isLocallyJoined(addr1) { - t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a query should not send any reports. - mgp.handleQuery(addr1, time.Nanosecond) - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Nanosecond) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Leaving groups should not send any leave messages. - if !mgp.leaveGroup(addr1) { - t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2) - } - if mgp.isLocallyJoined(addr1) { - t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2) - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} - -func TestQueuedPackets(t *testing.T) { - clock := faketime.NewManualClock() - mgp := mockMulticastGroupProtocol{t: t} - mgp.init(ip.GenericMulticastProtocolOptions{ - Rand: rand.New(rand.NewSource(4)), - Clock: clock, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) - - // Joining should trigger a SendReport, but mgp should report that we did not - // send the packet. - mgp.setQueuePackets(true) - mgp.joinGroup(addr1) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The delayed report timer should have been cancelled since we did not send - // the initial report earlier. - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Mock being able to successfully send the report. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // The delayed report (sent after the initial report) should now be sent. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send (we should be idle). - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receive a query but mock being unable to send reports again. - mgp.setQueuePackets(true) - mgp.handleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Mock being able to send reports again - we should have a packet queued to - // send. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send. - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receive a query again, but mock being unable to send reports. - mgp.setQueuePackets(true) - mgp.handleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Receiving a report should should transition us into the idle member state, - // even if we had a packet queued. We should no longer have any packets to - // send. - mgp.handleReport(addr1) - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // When we fail to send the initial set of reports, incoming reports should - // not affect a newly joined group's reports from being sent. - mgp.setQueuePackets(true) - mgp.joinGroup(addr2) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - mgp.handleReport(addr2) - // Attempting to send queued reports while still unable to send reports should - // not change the host state. - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Mock being able to successfully send the report. - mgp.setQueuePackets(false) - mgp.sendQueuedReports() - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // The delayed report (sent after the initial report) should now be sent. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - - // Should not have anything else to send. - mgp.sendQueuedReports() - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } -} diff --git a/pkg/tcpip/network/internal/ip/ip_state_autogen.go b/pkg/tcpip/network/internal/ip/ip_state_autogen.go new file mode 100644 index 000000000..aee77044e --- /dev/null +++ b/pkg/tcpip/network/internal/ip/ip_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package ip diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD deleted file mode 100644 index 1c4f583c7..000000000 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - srcs = [ - "testutil.go", - "testutil_unsafe.go", - ], - visibility = [ - "//pkg/tcpip/network/arp:__pkg__", - "//pkg/tcpip/network/internal/fragmentation:__pkg__", - "//pkg/tcpip/network/ipv4:__pkg__", - "//pkg/tcpip/network/ipv6:__pkg__", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go deleted file mode 100644 index f5fa77b65..000000000 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil defines types and functions used to test Network Layer -// functionality such as IP fragmentation. -package testutil - -import ( - "fmt" - "math/rand" - "reflect" - "strings" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// MockLinkEndpoint is an endpoint used for testing, it stores packets written -// to it and can mock errors. -type MockLinkEndpoint struct { - // WrittenPackets is where packets written to the endpoint are stored. - WrittenPackets []*stack.PacketBuffer - - mtu uint32 - err tcpip.Error - allowPackets int -} - -// NewMockLinkEndpoint creates a new MockLinkEndpoint. -// -// err is the error that will be returned once allowPackets packets are written -// to the endpoint. -func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint { - return &MockLinkEndpoint{ - mtu: mtu, - err: err, - allowPackets: allowPackets, - } -} - -// MTU implements LinkEndpoint.MTU. -func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } - -// Capabilities implements LinkEndpoint.Capabilities. -func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } - -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } - -// LinkAddress implements LinkEndpoint.LinkAddress. -func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } - -// WritePacket implements LinkEndpoint.WritePacket. -func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - return nil -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - var n int - - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { - return n, err - } - n++ - } - - return n, nil -} - -// Attach implements LinkEndpoint.Attach. -func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements LinkEndpoint.IsAttached. -func (*MockLinkEndpoint) IsAttached() bool { return false } - -// Wait implements LinkEndpoint.Wait. -func (*MockLinkEndpoint) Wait() {} - -// ARPHardwareType implements LinkEndpoint.ARPHardwareType. -func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } - -// AddHeader implements LinkEndpoint.AddHeader. -func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - -// MakeRandPkt generates a randomized packet. transportHeaderLength indicates -// how many random bytes will be copied in the Transport Header. -// extraHeaderReserveLength indicates how much extra space will be reserved for -// the other headers. The payload is made from Views of the sizes listed in -// viewSizes. -func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { - var views buffer.VectorisedView - - for _, s := range viewSizes { - newView := buffer.NewView(s) - if _, err := rand.Read(newView); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - views.AppendView(newView) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, - Data: views, - }) - pkt.NetworkProtocolNumber = proto - if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - return pkt -} - -func checkFieldCounts(ref, multi reflect.Value) error { - refTypeName := ref.Type().Name() - multiTypeName := multi.Type().Name() - refNumField := ref.NumField() - multiNumField := multi.NumField() - - if refNumField != multiNumField { - return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) - } - - return nil -} - -func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { - s, ok := ref.Addr().Interface().(**tcpip.StatCounter) - if !ok { - return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) - } - - // The field names are expected to match (case insensitive). - if !strings.EqualFold(refName, multiName) { - return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) - } - - base := (*s).Value() - m.Increment() - if (*s).Value() != base+1 { - return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) - } - - return nil -} - -// ValidateMultiCounterStats verifies that every counter stored in multi is -// correctly tracking its counterpart in the given counters. -func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { - for _, c := range counters { - if err := checkFieldCounts(c, multi); err != nil { - return err - } - } - - for i := 0; i < multi.NumField(); i++ { - multiName := multi.Type().Field(i).Name - multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) - - if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { - for _, c := range counters { - if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { - return err - } - } - } else { - var countersNextField []reflect.Value - for _, c := range counters { - countersNextField = append(countersNextField, c.Field(i)) - } - if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { - return err - } - } - } - - return nil -} diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go deleted file mode 100644 index 5ff764800..000000000 --- a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testutil - -import ( - "reflect" - "unsafe" -) - -// unsafeExposeUnexportedFields takes a Value and returns a version of it in -// which even unexported fields can be read and written. -func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value { - return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem() -} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go deleted file mode 100644 index aee1652fa..000000000 --- a/pkg/tcpip/network/ip_test.go +++ /dev/null @@ -1,1936 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") - remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02") - ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00") - ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00") - ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03") - localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") - ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00") - ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - nicID = 1 -) - -var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ - Address: localIPv4Addr, - PrefixLen: 24, -} - -var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ - Address: localIPv6Addr, - PrefixLen: 120, -} - -type transportError struct { - origin tcpip.SockErrOrigin - typ uint8 - code uint8 - info uint32 - kind stack.TransportErrorKind -} - -// testObject implements two interfaces: LinkEndpoint and TransportDispatcher. -// The former is used to pretend that it's a link endpoint so that we can -// inspect packets written by the network endpoints. The latter is used to -// pretend that it's the network stack so that it can inspect incoming packets -// that have been handled by the network endpoints. -// -// Packets are checked by comparing their fields/values against the expected -// values stored in the test object itself. -type testObject struct { - t *testing.T - protocol tcpip.TransportProtocolNumber - contents []byte - srcAddr tcpip.Address - dstAddr tcpip.Address - v4 bool - transErr transportError - - dataCalls int - controlCalls int -} - -// checkValues verifies that the transport protocol, data contents, src & dst -// addresses of a packet match what's expected. If any field doesn't match, the -// test fails. -func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v buffer.View, srcAddr, dstAddr tcpip.Address) { - if protocol != t.protocol { - t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) - } - - if srcAddr != t.srcAddr { - t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) - } - - if dstAddr != t.dstAddr { - t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) - } - - if len(v) != len(t.contents) { - t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) - } - - for i := range t.contents { - if t.contents[i] != v[i] { - t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) - } - } -} - -// DeliverTransportPacket is called by network endpoints after parsing incoming -// packets. This is used by the test object to verify that the results of the -// parsing are expected. -func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { - netHdr := pkt.Network() - t.checkValues(protocol, pkt.Data().AsRange().ToOwnedView(), netHdr.SourceAddress(), netHdr.DestinationAddress()) - t.dataCalls++ - return stack.TransportPacketHandled -} - -// DeliverTransportError is called by network endpoints after parsing -// incoming control (ICMP) packets. This is used by the test object to verify -// that the results of the parsing are expected. -func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) { - t.checkValues(trans, pkt.Data().AsRange().ToOwnedView(), remote, local) - if diff := cmp.Diff( - t.transErr, - transportError{ - origin: transErr.Origin(), - typ: transErr.Type(), - code: transErr.Code(), - info: transErr.Info(), - kind: transErr.Kind(), - }, - cmp.AllowUnexported(transportError{}), - ); diff != "" { - t.t.Errorf("transport error mismatch (-want +got):\n%s", diff) - } - t.controlCalls++ -} - -// Attach is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (*testObject) IsAttached() bool { - return true -} - -// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that -// matches the linux loopback MTU. -func (*testObject) MTU() uint32 { - return 65536 -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (*testObject) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (*testObject) LinkAddress() tcpip.LinkAddress { - return "" -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*testObject) Wait() {} - -// WritePacket is called by network endpoints after producing a packet and -// writing it to the link endpoint. This is used by the test object to verify -// that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - var prot tcpip.TransportProtocolNumber - var srcAddr tcpip.Address - var dstAddr tcpip.Address - - if t.v4 { - h := header.IPv4(pkt.NetworkHeader().View()) - prot = tcpip.TransportProtocolNumber(h.Protocol()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - - } else { - h := header.IPv6(pkt.NetworkHeader().View()) - prot = tcpip.TransportProtocolNumber(h.NextHeader()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - } - t.checkValues(prot, pkt.Data().AsRange().ToOwnedView(), srcAddr, dstAddr) - return nil -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - panic("not implemented") -} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*testObject) ARPHardwareType() header.ARPHardwareType { - panic("not implemented") -} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - panic("not implemented") -} - -func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv4.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - Gateway: ipv4Gateway, - NIC: 1, - }}) - - return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) -} - -func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - s.CreateNIC(nicID, loopback.New()) - s.AddAddress(nicID, ipv6.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: ipv6Gateway, - NIC: 1, - }}) - - return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) -} - -func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - e := channel.New(1, mtu, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) - } - - v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} - if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) - } - - return s, e -} - -func buildDummyStack(t *testing.T) *stack.Stack { - t.Helper() - - s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) - return s -} - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - testObject - - mu struct { - sync.RWMutex - disabled bool - } -} - -func (*testInterface) ID() tcpip.NICID { - return nicID -} - -func (*testInterface) IsLoopback() bool { - return false -} - -func (*testInterface) Name() string { - return "" -} - -func (t *testInterface) Enabled() bool { - t.mu.RLock() - defer t.mu.RUnlock() - return !t.mu.disabled -} - -func (*testInterface) Promiscuous() bool { - return false -} - -func (*testInterface) Spoofing() bool { - return false -} - -func (t *testInterface) setEnabled(v bool) { - t.mu.Lock() - defer t.mu.Unlock() - t.mu.disabled = !v -} - -func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { - return nil -} - -func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { - return nil -} - -func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { - return false -} - -func TestSourceAddressValidation(t *testing.T) { - rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: src, - Dst: localIPv6Addr, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - tests := []struct { - name string - srcAddress tcpip.Address - rxICMP func(*channel.Endpoint, tcpip.Address) - valid bool - }{ - { - name: "IPv4 valid", - srcAddress: "\x01\x02\x03\x04", - rxICMP: rxIPv4ICMP, - valid: true, - }, - { - name: "IPv6 valid", - srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10", - rxICMP: rxIPv6ICMP, - valid: true, - }, - { - name: "IPv4 unspecified", - srcAddress: header.IPv4Any, - rxICMP: rxIPv4ICMP, - valid: true, - }, - { - name: "IPv6 unspecified", - srcAddress: header.IPv4Any, - rxICMP: rxIPv6ICMP, - valid: true, - }, - { - name: "IPv4 multicast", - srcAddress: "\xe0\x00\x00\x01", - rxICMP: rxIPv4ICMP, - valid: false, - }, - { - name: "IPv6 multicast", - srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - rxICMP: rxIPv6ICMP, - valid: false, - }, - { - name: "IPv4 broadcast", - srcAddress: header.IPv4Broadcast, - rxICMP: rxIPv4ICMP, - valid: false, - }, - { - name: "IPv4 subnet broadcast", - srcAddress: func() tcpip.Address { - subnet := localIPv4AddrWithPrefix.Subnet() - return subnet.Broadcast() - }(), - rxICMP: rxIPv4ICMP, - valid: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) - test.rxICMP(e, test.srcAddress) - - var wantValid uint64 - if test.valid { - wantValid = 1 - } - - if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want { - t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) - } - if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid { - t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid) - } - }) - } -} - -func TestEnableWhenNICDisabled(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - }{ - { - name: "IPv4", - protocolFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - }, - { - name: "IPv6", - protocolFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var nic testInterface - nic.setEnabled(false) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory}, - }) - p := s.NetworkProtocolInstance(test.protoNum) - - // We pass nil for all parameters except the NetworkInterface and Stack - // since Enable only depends on these. - ep := p.NewEndpoint(&nic, nil) - - // The endpoint should initially be disabled, regardless the NIC's enabled - // status. - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - nic.setEnabled(true) - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - - // Attempting to enable the endpoint while the NIC is disabled should - // fail. - nic.setEnabled(false) - err := ep.Enable() - if _, ok := err.(*tcpip.ErrNotPermitted); !ok { - t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{}) - } - // ep should consider the NIC's enabled status when determining its own - // enabled status so we "enable" the NIC to read just the endpoint's - // enabled status. - nic.setEnabled(true) - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - - // Enabling the interface after the NIC has been enabled should succeed. - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - if !ep.Enabled() { - t.Fatal("got ep.Enabled() = false, want = true") - } - - // ep should consider the NIC's enabled status when determining its own - // enabled status. - nic.setEnabled(false) - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - - // Disabling the endpoint when the NIC is enabled should make the endpoint - // disabled. - nic.setEnabled(true) - ep.Disable() - if ep.Enabled() { - t.Fatal("got ep.Enabled() = true, want = false") - } - }) - } -} - -func TestIPv4Send(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, - }, - } - ep := proto.NewEndpoint(&nic, nil) - defer ep.Close() - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Setup the packet buffer. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(ep.MaxHeaderLength()), - Data: payload.ToVectorisedView(), - }) - - // Issue the write. - nic.testObject.protocol = 123 - nic.testObject.srcAddr = localIPv4Addr - nic.testObject.dstAddr = remoteIPv4Addr - nic.testObject.contents = payload - - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ - Protocol: 123, - TTL: 123, - TOS: stack.DefaultTOS, - }, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestReceive(t *testing.T) { - tests := []struct { - name string - protoFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - v4 bool - epAddr tcpip.AddressWithPrefix - handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface) - }{ - { - name: "IPv4", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - v4: true, - epAddr: localIPv4Addr.WithPrefix(), - handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { - const totalLen = header.IPv4MinimumSize + 30 /* payload length */ - - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, - TTL: ipv4.DefaultTTL, - Protocol: 10, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - ep.HandlePacket(pkt) - }, - }, - { - name: "IPv6", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - v4: false, - epAddr: localIPv6Addr.WithPrefix(), - handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { - const payloadLen = 30 - view := buffer.NewView(header.IPv6MinimumSize + payloadLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: payloadLen, - TransportProtocol: 10, - HopLimit: ipv6.DefaultTTL, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen] - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - ep.HandlePacket(pkt) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, - }) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: test.v4, - }, - } - ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) - } - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) - } else { - ep.DecRef() - } - - stat := s.Stats().IP.PacketsReceived - if got := stat.Value(); got != 0 { - t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got) - } - test.handlePacket(t, ep, &nic) - if nic.testObject.dataCalls != 1 { - t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) - } - if got := stat.Value(); got != 1 { - t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) - } - }) - } -} - -func TestIPv4ReceiveControl(t *testing.T) { - const ( - mtu = 0xbeef - header.IPv4MinimumSize - dataLen = 8 - ) - - cases := []struct { - name string - expectedCount int - fragmentOffset uint16 - code header.ICMPv4Code - transErr transportError - trunc int - }{ - { - name: "FragmentationNeeded", - expectedCount: 1, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP, - typ: uint8(header.ICMPv4DstUnreachable), - code: uint8(header.ICMPv4FragmentationNeeded), - info: mtu, - kind: stack.PacketTooBigTransportError, - }, - trunc: 0, - }, - { - name: "Truncated (missing IPv4 header)", - expectedCount: 0, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize, - }, - { - name: "Truncated (partial offending packet's IP header)", - expectedCount: 0, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1, - }, - { - name: "Truncated (partial offending packet's data)", - expectedCount: 0, - fragmentOffset: 0, - code: header.ICMPv4FragmentationNeeded, - trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1, - }, - { - name: "Port unreachable", - expectedCount: 1, - fragmentOffset: 0, - code: header.ICMPv4PortUnreachable, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP, - typ: uint8(header.ICMPv4DstUnreachable), - code: uint8(header.ICMPv4PortUnreachable), - kind: stack.DestinationPortUnreachableTransportError, - }, - trunc: 0, - }, - { - name: "Non-zero fragment offset", - expectedCount: 0, - fragmentOffset: 100, - code: header.ICMPv4PortUnreachable, - trunc: 0, - }, - { - name: "Zero-length packet", - expectedCount: 0, - fragmentOffset: 100, - code: header.ICMPv4PortUnreachable, - trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize - view := buffer.NewView(dataOffset + dataLen) - - // Create the outer IPv4 header. - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(view) - c.trunc), - TTL: 20, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Create the ICMP header. - icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) - icmp.SetType(header.ICMPv4DstUnreachable) - icmp.SetCode(c.code) - icmp.SetIdent(0xdead) - icmp.SetSequence(0xbeef) - - // Create the inner IPv4 header. - ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) - ip.Encode(&header.IPv4Fields{ - TotalLength: 100, - TTL: 20, - Protocol: 10, - FragmentOffset: c.fragmentOffset, - SrcAddr: localIPv4Addr, - DstAddr: remoteIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - icmp.SetChecksum(0) - checksum := ^header.Checksum(icmp, 0 /* initial */) - icmp.SetChecksum(checksum) - - // Give packet to IPv4 endpoint, dispatcher will validate that - // it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[dataOffset:] - nic.testObject.transErr = c.transErr - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") - } - addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) - ep.HandlePacket(pkt) - if want := c.expectedCount; nic.testObject.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) - } - }) - } -} - -func TestIPv4FragmentationReceive(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, - }, - } - ep := proto.NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - totalLen := header.IPv4MinimumSize + 24 - - frag1 := buffer.NewView(totalLen) - ip1 := header.IPv4(frag1) - ip1.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 0, - Flags: header.IPv4FlagMoreFragments, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip1.SetChecksum(^ip1.CalculateChecksum()) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag1[i] = uint8(i) - } - - frag2 := buffer.NewView(totalLen) - ip2 := header.IPv4(frag2) - ip2.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 24, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip2.SetChecksum(^ip2.CalculateChecksum()) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag2[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - - // Send first segment. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: frag1.ToVectorisedView(), - }) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") - } - addr := localIPv4Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) - } - - // Send second segment. - pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: frag2.ToVectorisedView(), - }) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) - } -} - -func TestIPv6Send(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, nil) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Setup the packet buffer. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(ep.MaxHeaderLength()), - Data: payload.ToVectorisedView(), - }) - - // Issue the write. - nic.testObject.protocol = 123 - nic.testObject.srcAddr = localIPv6Addr - nic.testObject.dstAddr = remoteIPv6Addr - nic.testObject.contents = payload - - r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ - Protocol: 123, - TTL: 123, - TOS: stack.DefaultTOS, - }, pkt); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestIPv6ReceiveControl(t *testing.T) { - const ( - mtu = 0xffff - outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" - dataLen = 8 - ) - - newUint16 := func(v uint16) *uint16 { return &v } - - portUnreachableTransErr := transportError{ - origin: tcpip.SockExtErrorOriginICMP6, - typ: uint8(header.ICMPv6DstUnreachable), - code: uint8(header.ICMPv6PortUnreachable), - kind: stack.DestinationPortUnreachableTransportError, - } - - cases := []struct { - name string - expectedCount int - fragmentOffset *uint16 - typ header.ICMPv6Type - code header.ICMPv6Code - transErr transportError - trunc int - }{ - { - name: "PacketTooBig", - expectedCount: 1, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP6, - typ: uint8(header.ICMPv6PacketTooBig), - code: uint8(header.ICMPv6UnusedCode), - info: mtu, - kind: stack.PacketTooBigTransportError, - }, - trunc: 0, - }, - { - name: "Truncated (missing offending packet's IPv6 header)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize, - }, - { - name: "Truncated PacketTooBig (partial offending packet's IPv6 header)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1, - }, - { - name: "Truncated (partial offending packet's data)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6PacketTooBig, - code: header.ICMPv6UnusedCode, - trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1, - }, - { - name: "Port unreachable", - expectedCount: 1, - fragmentOffset: nil, - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - transErr: portUnreachableTransErr, - trunc: 0, - }, - { - name: "Truncated DstPortUnreachable (partial offending packet's IP header)", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1, - }, - { - name: "DstPortUnreachable for Fragmented, zero offset", - expectedCount: 1, - fragmentOffset: newUint16(0), - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - transErr: portUnreachableTransErr, - trunc: 0, - }, - { - name: "DstPortUnreachable for Non-zero fragment offset", - expectedCount: 0, - fragmentOffset: newUint16(100), - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - transErr: portUnreachableTransErr, - trunc: 0, - }, - { - name: "Zero-length packet", - expectedCount: 0, - fragmentOffset: nil, - typ: header.ICMPv6DstUnreachable, - code: header.ICMPv6PortUnreachable, - trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize - if c.fragmentOffset != nil { - dataOffset += header.IPv6FragmentHeaderSize - } - view := buffer.NewView(dataOffset + dataLen) - - // Create the outer IPv6 header. - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 20, - SrcAddr: outerSrcAddr, - DstAddr: localIPv6Addr, - }) - - // Create the ICMP header. - icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) - icmp.SetType(c.typ) - icmp.SetCode(c.code) - icmp.SetIdent(0xdead) - icmp.SetSequence(0xbeef) - - var extHdrs header.IPv6ExtHdrSerializer - // Build the fragmentation header if needed. - if c.fragmentOffset != nil { - extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: *c.fragmentOffset, - M: true, - Identification: 0x12345678, - }) - } - - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - TransportProtocol: 10, - HopLimit: 20, - SrcAddr: localIPv6Addr, - DstAddr: remoteIPv6Addr, - ExtensionHeaders: extHdrs, - }) - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to IPv6 endpoint, dispatcher will validate that - // it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[dataOffset:] - nic.testObject.transErr = c.transErr - - // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: outerSrcAddr, - Dst: localIPv6Addr, - })) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") - } - addr := localIPv6Addr.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) - ep.HandlePacket(pkt) - if want := c.expectedCount; nic.testObject.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) - } - }) - } -} - -// truncatedPacket returns a PacketBuffer based on a truncated view. If view, -// after truncation, is large enough to hold a network header, it makes part of -// view the packet's NetworkHeader and the rest its Data. Otherwise all of view -// becomes Data. -func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { - v := view[:len(view)-trunc] - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - }) - return pkt -} - -func TestWriteHeaderIncludedPacket(t *testing.T) { - const ( - nicID = 1 - transportProto = 5 - - dataLen = 4 - ) - - dataBuf := [dataLen]byte{1, 2, 3, 4} - data := dataBuf[:] - - ipv4Options := header.IPv4OptionsSerializer{ - &header.IPv4SerializableListEndOption{}, - &header.IPv4SerializableNOPOption{}, - &header.IPv4SerializableListEndOption{}, - &header.IPv4SerializableNOPOption{}, - } - - expectOptions := header.IPv4Options{ - byte(header.IPv4OptionListEndType), - byte(header.IPv4OptionNOPType), - byte(header.IPv4OptionListEndType), - byte(header.IPv4OptionNOPType), - } - - ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} - ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] - - var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte - ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:] - if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) - } - if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - - tests := []struct { - name string - protoFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - nicAddr tcpip.Address - remoteAddr tcpip.Address - pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView - checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) - expectedErr tcpip.Error - }{ - { - name: "IPv4", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv4MinimumSize + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv4MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv4 with IHL too small", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv4MinimumSize + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - ip.SetHeaderLength(header.IPv4MinimumSize - 1) - return hdr.View().ToVectorisedView() - }, - expectedErr: &tcpip.ErrMalformedHeader{}, - }, - { - name: "IPv4 too small", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return buffer.View(ip[:len(ip)-1]).ToVectorisedView() - }, - expectedErr: &tcpip.ErrMalformedHeader{}, - }, - { - name: "IPv4 minimum size", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return buffer.View(ip).ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv4MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.IPFullLength(header.IPv4MinimumSize), - checker.IPPayload(nil), - ) - }, - }, - { - name: "IPv4 with options", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) - totalLen := ipHdrLen + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv4(hdr.Prepend(ipHdrLen)) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - Options: ipv4Options, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) - if len(netHdr.View()) != hdrLen { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(hdrLen), - checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(expectOptions), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv4 with options and data across views", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - nicAddr: localIPv4Addr, - remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) - ip.Encode(&header.IPv4Fields{ - Protocol: transportProto, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - Options: ipv4Options, - }) - vv := buffer.View(ip).ToVectorisedView() - vv.AppendView(data) - return vv - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv4Any { - src = localIPv4Addr - } - - netHdr := pkt.NetworkHeader() - - hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) - if len(netHdr.View()) != hdrLen { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) - } - - checker.IPv4(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(hdrLen), - checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(expectOptions), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv6", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv6MinimumSize + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - TransportProtocol: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv6Any { - src = localIPv6Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv6MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) - } - - checker.IPv6(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))), - checker.IPPayload(data), - ) - }, - }, - { - name: "IPv6 with extension header", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) - hdr := buffer.NewPrependable(totalLen) - if n := copy(hdr.Prepend(len(data)), data); n != len(data) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) - } - if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) - } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - // NB: we're lying about transport protocol here to verify the raw - // fragment header bytes. - TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return hdr.View().ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv6Any { - src = localIPv6Addr - } - - netHdr := pkt.NetworkHeader() - - if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want) - } - - checker.IPv6(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))), - checker.IPPayload(ipv6PayloadWithExtHdr), - ) - }, - }, - { - name: "IPv6 minimum size", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - TransportProtocol: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return buffer.View(ip).ToVectorisedView() - }, - checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { - if src == header.IPv6Any { - src = localIPv6Addr - } - - netHdr := pkt.NetworkHeader() - - if len(netHdr.View()) != header.IPv6MinimumSize { - t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) - } - - checker.IPv6(t, stack.PayloadSince(netHdr), - checker.SrcAddr(src), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(header.IPv6MinimumSize), - checker.IPPayload(nil), - ) - }, - }, - { - name: "IPv6 too small", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - nicAddr: localIPv6Addr, - remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - TransportProtocol: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, - }) - return buffer.View(ip[:len(ip)-1]).ToVectorisedView() - }, - expectedErr: &tcpip.ErrMalformedHeader{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - subTests := []struct { - name string - srcAddr tcpip.Address - }{ - { - name: "unspecified source", - srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), - }, - { - name: "random source", - srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), - }, - } - - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, - }) - e := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) - - r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) - } - defer r.Release() - - { - err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr), - })) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff) - } - } - - if test.expectedErr != nil { - return - } - - pkt, ok := e.Read() - if !ok { - t.Fatal("expected a packet to be written") - } - test.checker(t, pkt.Pkt, subTest.srcAddr) - }) - } - }) - } -} - -// Test that the included data in an ICMP error packet conforms to the -// requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3 -func TestICMPInclusionSize(t *testing.T) { - const ( - replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize - replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize - targetSize4 = header.IPv4MinimumProcessableDatagramSize - targetSize6 = header.IPv6MinimumMTU - // A protocol number that will cause an error response. - reservedProtocol = 254 - ) - - // IPv4 function to create a IP packet and send it to the stack. - // The packet should generate an error response. We can do that by using an - // unknown transport protocol (254). - rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { - totalLen := header.IPv4MinimumSize + len(payload) - hdr := buffer.NewPrependable(header.IPv4MinimumSize) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: reservedProtocol, - TTL: ipv4.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - vv := hdr.View().ToVectorisedView() - vv.AppendView(buffer.View(payload)) - // Take a copy before InjectInbound takes ownership of vv - // as vv may be changed during the call. - v := vv.ToView() - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - return v - } - - // IPv6 function to create a packet and send it to the stack. - // The packet should be errant in a way that causes the stack to send an - // ICMP error response and have enough data to allow the testing of the - // inclusion of the errant packet. Use `unknown next header' to generate - // the error. - rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload)), - TransportProtocol: reservedProtocol, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, - }) - vv := hdr.View().ToVectorisedView() - vv.AppendView(buffer.View(payload)) - // Take a copy before InjectInbound takes ownership of vv - // as vv may be changed during the call. - v := vv.ToView() - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - return v - } - - v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { - // We already know the entire packet is the right size so we can use its - // length to calculate the right payload size to check. - expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize - checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()), - checker.SrcAddr(localIPv4Addr), - checker.DstAddr(remoteIPv4Addr), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4ProtoUnreachable), - checker.ICMPv4Payload(payload[:expectedPayloadLength]), - ), - ) - } - - v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { - // We already know the entire packet is the right size so we can use its - // length to calculate the right payload size to check. - expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize - checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()), - checker.SrcAddr(localIPv6Addr), - checker.DstAddr(remoteIPv6Addr), - checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6ParamProblem), - checker.ICMPv6Code(header.ICMPv6UnknownHeader), - checker.ICMPv6Payload(payload[:expectedPayloadLength]), - ), - ) - } - tests := []struct { - name string - srcAddress tcpip.Address - injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View - checker func(*testing.T, *stack.PacketBuffer, buffer.View) - payloadLength int // Not including IP header. - linkMTU uint32 // Largest IP packet that the link can send as payload. - replyLength int // Total size of IP/ICMP packet expected back. - }{ - { - name: "IPv4 exact match", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4 - replyHeaderLength4, - linkMTU: targetSize4, - replyLength: targetSize4, - }, - { - name: "IPv4 larger MTU", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4, - linkMTU: targetSize4 + 1000, - replyLength: targetSize4, - }, - { - name: "IPv4 smaller MTU", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4, - linkMTU: targetSize4 - 50, - replyLength: targetSize4 - 50, - }, - { - name: "IPv4 payload exceeds", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4 + 10, - linkMTU: targetSize4, - replyLength: targetSize4, - }, - { - name: "IPv4 1 byte less", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: targetSize4 - replyHeaderLength4 - 1, - linkMTU: targetSize4, - replyLength: targetSize4 - 1, - }, - { - name: "IPv4 No payload", - srcAddress: remoteIPv4Addr, - injector: rxIPv4Bad, - checker: v4Checker, - payloadLength: 0, - linkMTU: targetSize4, - replyLength: replyHeaderLength4, - }, - { - name: "IPv6 exact match", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6 - replyHeaderLength6, - linkMTU: targetSize6, - replyLength: targetSize6, - }, - { - name: "IPv6 larger MTU", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6, - linkMTU: targetSize6 + 400, - replyLength: targetSize6, - }, - // NB. No "smaller MTU" test here as less than 1280 is not permitted - // in IPv6. - { - name: "IPv6 payload exceeds", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6, - linkMTU: targetSize6, - replyLength: targetSize6, - }, - { - name: "IPv6 1 byte less", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: targetSize6 - replyHeaderLength6 - 1, - linkMTU: targetSize6, - replyLength: targetSize6 - 1, - }, - { - name: "IPv6 no payload", - srcAddress: remoteIPv6Addr, - injector: rxIPv6Bad, - checker: v6Checker, - payloadLength: 0, - linkMTU: targetSize6, - replyLength: replyHeaderLength6, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU) - // Allocate and initialize the payload view. - payload := buffer.NewView(test.payloadLength) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - // Default routes for IPv4&6 so ICMP can find a route to the remote - // node when attempting to send the ICMP error Reply. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - v := test.injector(e, test.srcAddress, payload) - pkt, ok := e.Read() - if !ok { - t.Fatal("expected a packet to be written") - } - if got, want := pkt.Pkt.Size(), test.replyLength; got != want { - t.Fatalf("got %d bytes of icmp error packet, want %d", got, want) - } - test.checker(t, pkt.Pkt, v) - }) - } -} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD deleted file mode 100644 index 5e7f10f4b..000000000 --- a/pkg/tcpip/network/ipv4/BUILD +++ /dev/null @@ -1,67 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv4", - srcs = [ - "icmp.go", - "igmp.go", - "ipv4.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/hash", - "//pkg/tcpip/network/internal/fragmentation", - "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv4_test", - size = "small", - srcs = [ - "igmp_test.go", - "ipv4_test.go", - ], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "stats_test", - size = "small", - srcs = ["stats_test.go"], - library = ":ipv4", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go deleted file mode 100644 index e5e1b89cc..000000000 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ /dev/null @@ -1,383 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - stackAddr = tcpip.Address("\x0a\x00\x00\x01") - remoteAddr = tcpip.Address("\x0a\x00\x00\x02") - multicastAddr = tcpip.Address("\xe0\x00\x00\x03") - nicID = 1 - defaultTTL = 1 - defaultPrefixLength = 24 -) - -// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet -// sent to the provided address with the passed fields set. Raises a t.Error if -// any field does not match. -func validateIgmpPacket(t *testing.T, p channel.PacketInfo, igmpType header.IGMPType, maxRespTime byte, srcAddr, dstAddr, groupAddress tcpip.Address) { - t.Helper() - - payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv4(t, payload, - checker.SrcAddr(srcAddr), - checker.DstAddr(dstAddr), - // TTL for an IGMP message must be 1 as per RFC 2236 section 2. - checker.TTL(1), - checker.IPv4RouterAlert(), - checker.IGMP( - checker.IGMPType(igmpType), - checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), - checker.IGMPGroupAddress(groupAddress), - ), - ) -} - -func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { - t.Helper() - - // Create an endpoint of queue size 1, since no more than 1 packets are ever - // queued in the tests in this file. - e := channel.New(1, 1280, linkAddr) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ - IGMP: ipv4.IGMPOptions{ - Enabled: igmpEnabled, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - return e, s, clock -} - -func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, ttl uint8, srcAddr, dstAddr, groupAddress tcpip.Address, hasRouterAlertOption bool) { - var options header.IPv4OptionsSerializer - if hasRouterAlertOption { - options = header.IPv4OptionsSerializer{ - &header.IPv4SerializableRouterAlertOption{}, - } - } - buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize) - - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: ttl, - Protocol: uint8(header.IGMPProtocolNumber), - SrcAddr: srcAddr, - DstAddr: dstAddr, - Options: options, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - igmp := header.IGMP(ip.Payload()) - igmp.SetType(igmpType) - igmp.SetMaxRespTime(maxRespTime) - igmp.SetGroupAddress(groupAddress) - igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - - e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) -} - -// TestIGMPV1Present tests the node's ability to fallback to V1 when a V1 -// router is detected. V1 present status is expected to be reset when the NIC -// cycles. -func TestIGMPV1Present(t *testing.T) { - e, s, clock := createStack(t, true) - addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength} - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) - } - - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - // This NIC will send an IGMPv2 report immediately, before this test can get - // the IGMPv1 General Membership Query in. - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - - // Inject an IGMPv1 General Membership Query which is identical to a standard - // membership query except the Max Response Time is set to 0, which will tell - // the stack that this is a router using IGMPv1. Send it to the all systems - // group which is the only group this host belongs to. - createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, defaultTTL, remoteAddr, stackAddr, header.IPv4AllSystems, true /* hasRouterAlertOption */) - if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 { - t.Fatalf("got Membership Queries received = %d, want = 1", got) - } - - // Before advancing the clock, verify that this host has not sent a - // V1MembershipReport yet. - if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 { - t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got) - } - - // Verify the solicited Membership Report is sent. Now that this NIC has seen - // an IGMPv1 query, it should send an IGMPv1 Membership Report. - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { - t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, header.IGMPv1MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - - // Cycling the interface should reset the V1 present flag. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) - } - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } -} - -func TestSendQueuedIGMPReports(t *testing.T) { - e, s, clock := createStack(t, true) - - // Joining a group without an assigned address should queue IGMP packets; none - // should be sent without an assigned address. - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) - } - reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport - if got := reportStat.Value(); got != 0 { - t.Errorf("got reportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("got unexpected packet = %#v", p) - } - - // The initial set of IGMP reports that were queued should be sent once an - // address is assigned. - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) - } - if got := reportStat.Value(); got != 1 { - t.Errorf("got reportStat.Value() = %d, want = 1", got) - } - if p, ok := e.Read(); !ok { - t.Error("expected to send an IGMP membership report") - } else { - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - if got := reportStat.Value(); got != 2 { - t.Errorf("got reportStat.Value() = %d, want = 2", got) - } - if p, ok := e.Read(); !ok { - t.Error("expected to send an IGMP membership report") - } else { - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - - // Should have no more packets to send after the initial set of unsolicited - // reports. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("got unexpected packet = %#v", p) - } -} - -func TestIGMPPacketValidation(t *testing.T) { - tests := []struct { - name string - messageType header.IGMPType - stackAddresses []tcpip.AddressWithPrefix - srcAddr tcpip.Address - includeRouterAlertOption bool - ttl uint8 - expectValidIGMP bool - getMessageTypeStatValue func(tcpip.Stats) uint64 - }{ - { - name: "valid", - messageType: header.IGMPLeaveGroup, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, - }, - { - name: "bad ttl", - messageType: header.IGMPv1MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 2, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, - }, - { - name: "missing router alert ip option", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: false, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - { - name: "igmp leave group and src ip does not belong to nic subnet", - messageType: header.IGMPLeaveGroup, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, - }, - { - name: "igmp query and src ip does not belong to nic subnet", - messageType: header.IGMPMembershipQuery, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() }, - }, - { - name: "igmp report v1 and src ip does not belong to nic subnet", - messageType: header.IGMPv1MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, - }, - { - name: "igmp report v2 and src ip does not belong to nic subnet", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - { - name: "src ip belongs to the subnet of the nic's second address", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{ - {Address: tcpip.Address("\x0a\x00\x0f\x01"), PrefixLen: 24}, - {Address: stackAddr, PrefixLen: 24}, - }, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, true) - for _, address := range test.stackAddresses { - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err) - } - } - stats := s.Stats() - // Verify that every relevant stats is zero'd before we send a packet. - if got := test.getMessageTypeStatValue(s.Stats()); got != 0 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got) - } - if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != 0 { - t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = 0", got) - } - if got := stats.IP.PacketsDelivered.Value(); got != 0 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got) - } - createAndInjectIGMPPacket(e, test.messageType, 0, test.ttl, test.srcAddr, header.IPv4AllSystems, header.IPv4AllSystems, test.includeRouterAlertOption) - // We always expect the packet to pass IP validation. - if got := stats.IP.PacketsDelivered.Value(); got != 1 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got) - } - // Even when the IGMP-specific validation checks fail, we expect the - // corresponding IGMP counter to be incremented. - if got := test.getMessageTypeStatValue(s.Stats()); got != 1 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got) - } - var expectedInvalidCount uint64 - if !test.expectValidIGMP { - expectedInvalidCount = 1 - } - if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != expectedInvalidCount { - t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go new file mode 100644 index 000000000..87a48e2ce --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go @@ -0,0 +1,105 @@ +// automatically generated by stateify. + +package ipv4 + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *icmpv4DestinationUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationUnreachableSockError" +} + +func (i *icmpv4DestinationUnreachableSockError) StateFields() []string { + return []string{} +} + +func (i *icmpv4DestinationUnreachableSockError) beforeSave() {} + +func (i *icmpv4DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *icmpv4DestinationUnreachableSockError) afterLoad() {} + +func (i *icmpv4DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) { +} + +func (i *icmpv4DestinationHostUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationHostUnreachableSockError" +} + +func (i *icmpv4DestinationHostUnreachableSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + } +} + +func (i *icmpv4DestinationHostUnreachableSockError) beforeSave() {} + +func (i *icmpv4DestinationHostUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationHostUnreachableSockError) afterLoad() {} + +func (i *icmpv4DestinationHostUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationPortUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationPortUnreachableSockError" +} + +func (i *icmpv4DestinationPortUnreachableSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + } +} + +func (i *icmpv4DestinationPortUnreachableSockError) beforeSave() {} + +func (i *icmpv4DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationPortUnreachableSockError) afterLoad() {} + +func (i *icmpv4DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (e *icmpv4FragmentationNeededSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4FragmentationNeededSockError" +} + +func (e *icmpv4FragmentationNeededSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + "mtu", + } +} + +func (e *icmpv4FragmentationNeededSockError) beforeSave() {} + +func (e *icmpv4FragmentationNeededSockError) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.icmpv4DestinationUnreachableSockError) + stateSinkObject.Save(1, &e.mtu) +} + +func (e *icmpv4FragmentationNeededSockError) afterLoad() {} + +func (e *icmpv4FragmentationNeededSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.icmpv4DestinationUnreachableSockError) + stateSourceObject.Load(1, &e.mtu) +} + +func init() { + state.Register((*icmpv4DestinationUnreachableSockError)(nil)) + state.Register((*icmpv4DestinationHostUnreachableSockError)(nil)) + state.Register((*icmpv4DestinationPortUnreachableSockError)(nil)) + state.Register((*icmpv4FragmentationNeededSockError)(nil)) +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go deleted file mode 100644 index cfed241bf..000000000 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ /dev/null @@ -1,3106 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "bytes" - "context" - "encoding/hex" - "fmt" - "io/ioutil" - "math" - "net" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/raw" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - extraHeaderReserve = 50 - defaultMTU = 65536 -) - -func TestExcludeBroadcast(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) - if testing.Verbose() { - ep = sniffer.New(ep) - } - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }}) - - randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} - - var wq waiter.Queue - t.Run("WithoutPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Cannot connect using a broadcast address as the source. - { - err := ep.Connect(randomAddr) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{}) - } - } - - // However, we can bind to a broadcast address to listen. - if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { - t.Errorf("Bind failed: %v", err) - } - }) - - t.Run("WithPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Add a valid primary endpoint address, now we can connect. - if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - if err := ep.Connect(randomAddr); err != nil { - t.Errorf("Connect failed: %v", err) - } - }) -} - -func TestForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - randomSequence = 123 - randomIdent = 42 - randomTimeOffset = 0x10203040 - ) - - ipv4Addr1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), - PrefixLen: 8, - } - ipv4Addr2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), - PrefixLen: 8, - } - remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) - remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) - - tests := []struct { - name string - TTL uint8 - expectErrorICMP bool - options header.IPv4Options - forwardedOptions header.IPv4Options - icmpType header.ICMPv4Type - icmpCode header.ICMPv4Code - }{ - { - name: "TTL of zero", - TTL: 0, - expectErrorICMP: true, - icmpType: header.ICMPv4TimeExceeded, - icmpCode: header.ICMPv4TTLExceeded, - }, - { - name: "TTL of one", - TTL: 1, - expectErrorICMP: false, - }, - { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, - }, - { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, - }, - { - name: "four EOL options", - TTL: 2, - expectErrorICMP: false, - options: header.IPv4Options{0, 0, 0, 0}, - forwardedOptions: header.IPv4Options{0, 0, 0, 0}, - }, - { - name: "TS type 1 full", - TTL: 2, - options: header.IPv4Options{ - 68, 12, 13, 0xF1, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - expectErrorICMP: true, - icmpType: header.ICMPv4ParamProblem, - icmpCode: header.ICMPv4UnusedCode, - }, - { - name: "TS type 0", - TTL: 2, - options: header.IPv4Options{ - 68, 24, 21, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - }, - forwardedOptions: header.IPv4Options{ - 68, 24, 25, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - name: "end of options list", - TTL: 2, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 10, 3, 99, // EOL followed by junk - 1, 2, 3, 4, - }, - forwardedOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, // End of Options hides following bytes. - 0, 0, 0, // 7 bytes unknown option removed. - 0, 0, 0, 0, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - Clock: clock, - }) - - // Advance the clock by some unimportant amount to make - // it give a more recognisable signature than 00,00,00,00. - clock.Advance(time.Millisecond * randomTimeOffset) - - // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} - if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) - } - - e2 := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} - if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: ipv4Addr1.Subnet(), - NIC: nicID1, - }, - { - Destination: ipv4Addr2.Subnet(), - NIC: nicID2, - }, - }) - - if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) - } - - ipHeaderLength := header.IPv4MinimumSize + len(test.options) - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, - Protocol: uint8(header.ICMPv4ProtocolNumber), - TTL: test.TTL, - SrcAddr: remoteIPv4Addr1, - DstAddr: remoteIPv4Addr2, - }) - if len(test.options) != 0 { - ip.SetHeaderLength(uint8(ipHeaderLength)) - // Copy options manually. We do not use Encode for options so we can - // verify malformed options with handcrafted payloads. - if want, got := copy(ip.Options(), test.options), len(test.options); want != got { - t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) - } - } - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) - - if test.expectErrorICMP { - reply, ok := e1.Read() - if !ok { - t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) - } - - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv4Addr1.Address), - checker.DstAddr(remoteIPv4Addr1), - checker.TTL(ipv4.DefaultTTL), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(test.icmpType), - checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload([]byte(hdr.View())), - ), - ) - - if n := e2.Drain(); n != 0 { - t.Fatalf("got e2.Drain() = %d, want = 0", n) - } - } else { - reply, ok := e2.Read() - if !ok { - t.Fatal("expected ICMP Echo packet through outgoing NIC") - } - - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv4Addr1), - checker.DstAddr(remoteIPv4Addr2), - checker.TTL(test.TTL-1), - checker.IPv4Options(test.forwardedOptions), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4Echo), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Payload(nil), - ), - ) - - if n := e1.Drain(); n != 0 { - t.Fatalf("got e1.Drain() = %d, want = 0", n) - } - } - }) - } -} - -// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and -// checks the response. -func TestIPv4Sanity(t *testing.T) { - const ( - ttl = 255 - nicID = 1 - randomSequence = 123 - randomIdent = 42 - // In some cases Linux sets the error pointer to the start of the option - // (offset 0) instead of the actual wrong value, which is the length byte - // (offset 1). For compatibility we must do the same. Use this constant - // to indicate where this happens. - pointerOffsetForInvalidLength = 0 - randomTimeOffset = 0x10203040 - ) - var ( - ipv4Addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), - PrefixLen: 24, - } - remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) - ) - - tests := []struct { - name string - headerLength uint8 // value of 0 means "use correct size" - badHeaderChecksum bool - maxTotalLength uint16 - transportProtocol uint8 - TTL uint8 - options header.IPv4Options - replyOptions header.IPv4Options // reply should look like this - shouldFail bool - expectErrorICMP bool - ICMPType header.ICMPv4Type - ICMPCode header.ICMPv4Code - paramProblemPointer uint8 - }{ - { - name: "valid no options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - }, - { - name: "bad header checksum", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - badHeaderChecksum: true, - shouldFail: true, - }, - // The TTL tests check that we are not rejecting an incoming packet - // with a zero or one TTL, which has been a point of confusion in the - // past as RFC 791 says: "If this field contains the value zero, then the - // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies - // for the case of the destination host, stating as follows. - // - // A host MUST NOT send a datagram with a Time-to-Live (TTL) - // value of zero. - // - // A host MUST NOT discard a datagram just because it was - // received with TTL less than 2. - { - name: "zero TTL", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: 0, - }, - { - name: "one TTL", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: 1, - }, - { - name: "End options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{0, 0, 0, 0}, - replyOptions: header.IPv4Options{0, 0, 0, 0}, - }, - { - name: "NOP options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{1, 1, 1, 1}, - replyOptions: header.IPv4Options{1, 1, 1, 1}, - }, - { - name: "NOP and End options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{1, 1, 0, 0}, - replyOptions: header.IPv4Options{1, 1, 0, 0}, - }, - { - name: "bad header length", - headerLength: header.IPv4MinimumSize - 1, - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (0)", - maxTotalLength: 0, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (ip - 1)", - maxTotalLength: uint16(header.IPv4MinimumSize - 1), - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (ip + icmp - 1)", - maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad protocol", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: 99, - TTL: ttl, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4DstUnreachable, - ICMPCode: header.ICMPv4ProtoUnreachable, - }, - { - name: "timestamp option overflow", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - replyOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - }, - { - name: "timestamp option overflow full", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0xF1, - // ^ Counter full (15/0xF) - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 3, - replyOptions: header.IPv4Options{}, - }, - { - name: "unknown option", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{10, 4, 9, 0}, - // ^^ - // The unknown option should be stripped out of the reply. - replyOptions: header.IPv4Options{}, - }, - { - name: "bad option - no length", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 1, 1, 1, 68, - // ^-start of timestamp.. but no length.. - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 3, - }, - { - name: "bad option - length 0", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 0, 9, 0, - // ^ - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "bad option - length 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 1, 9, 0, - // ^ - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "bad option - length big", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 9, 9, 0, - // ^ - // There are only 8 bytes allocated to options so 9 bytes of timestamp - // space is not possible. (Second byte) - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - // This tests for some linux compatible behaviour. - // The ICMP pointer returned is 22 for Linux but the - // error is actually in spot 21. - name: "bad option - length bad", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - // Timestamps are in multiples of 4 or 8 but never 7. - // The option space should be padded out. - options: header.IPv4Options{ - 68, 7, 5, 0, - // ^ ^ Linux points here which is wrong. - // | Not a multiple of 4 - 1, 2, 3, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - { - name: "multiple type 0 with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 24, 21, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 24, 25, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // The timestamp area is full so add to the overflow count. - name: "multiple type 1 timestamps", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 20, 21, 0x11, - // ^ - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - }, - // Overflow count is the top nibble of the 4th byte. - replyOptions: header.IPv4Options{ - 68, 20, 21, 0x21, - // ^ - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - }, - }, - { - name: "multiple type 1 timestamps with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 28, 21, 0x01, - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 28, 29, 0x01, - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - 192, 168, 1, 58, // New IP Address. - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // Timestamp pointer uses one based counting so 0 is invalid. - name: "timestamp pointer invalid", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, 0, 0x00, - // ^ 0 instead of 5 or more. - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, - }, - { - // Timestamp pointer cannot be less than 5. It must point past the header - // which is 4 bytes. (1 based counting) - name: "timestamp pointer too small by 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, header.IPv4OptionTimestampHdrLength, 0x00, - // ^ header is 4 bytes, so 4 should fail. - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - { - name: "valid timestamp pointer", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00, - // ^ header is 4 bytes, so 5 should succeed. - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 8, 9, 0x00, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // Needs 8 bytes for a type 1 timestamp but there are only 4 free. - name: "bad timer element alignment", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 20, 17, 0x01, - // ^^ ^^ 20 byte area, next free spot at 17. - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - // End of option list with illegal option after it, which should be ignored. - { - name: "end of options list", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 10, 3, 99, // EOL followed by junk - }, - replyOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, // End of Options hides following bytes. - 0, 0, 0, // 3 bytes unknown option removed. - }, - }, - { - // Timestamp with a size much too small. - name: "timestamp truncated", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 1, 0, 0, - // ^ Smallest possible is 8. Linux points at the 68. - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "single record route with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 4, // 3 byte header - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - name: "multiple record route with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 23, 20, // 3 byte header - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 23, 24, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - name: "single record route with no room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Unlike timestamp, this should just succeed. - name: "multiple record route with no room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 23, 24, // 3 byte header - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 23, 24, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Pointer uses one based counting so 0 is invalid. - name: "record route pointer zero", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, 0, // 3 byte header - 0, 0, 0, 0, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - // Pointer must be 4 or more as it must point past the 3 byte header - // using 1 based counting. 3 should fail. - name: "record route pointer too small by 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header - 0, 0, 0, 0, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - // Pointer must be 4 or more as it must point past the 3 byte header - // using 1 based counting. Check 4 passes. (Duplicates "single - // record route with room") - name: "valid record route pointer", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Confirm Linux bug for bug compatibility. - // Linux returns slot 22 but the error is in slot 21. - name: "multiple record route with not enough room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, 8, // 3 byte header - // ^ ^ Linux points here. We must too. - // | Not enough room. 1 byte free, need 4. - 1, 2, 3, 4, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - name: "duplicate record route", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, 0, // pad - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 7, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - Clock: clock, - }) - // Advance the clock by some unimportant amount to make - // it give a more recognisable signature than 00,00,00,00. - clock.Advance(time.Millisecond * randomTimeOffset) - - // We expect at most a single packet in response to our ICMP Echo Request. - e := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) - } - - // Default routes for IPv4 so ICMP can find a route to the remote - // node when attempting to send the ICMP Echo Reply. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - if len(test.options)%4 != 0 { - t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options)) - } - ipHeaderLength := header.IPv4MinimumSize + len(test.options) - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - - // Specify ident/seq to make sure we get the same in the response. - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - if test.maxTotalLength < totalLen { - totalLen = test.maxTotalLength - } - ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, - Protocol: test.transportProtocol, - TTL: test.TTL, - SrcAddr: remoteIPv4Addr, - DstAddr: ipv4Addr.Address, - }) - if test.headerLength != 0 { - ip.SetHeaderLength(test.headerLength) - } else { - // Set the calculated header length, since we may manually add options. - ip.SetHeaderLength(uint8(ipHeaderLength)) - } - if len(test.options) != 0 { - // Copy options manually. We do not use Encode for options so we can - // verify malformed options with handcrafted payloads. - if want, got := copy(ip.Options(), test.options), len(test.options); want != got { - t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) - } - } - ip.SetChecksum(0) - ipHeaderChecksum := ip.CalculateChecksum() - if test.badHeaderChecksum { - ipHeaderChecksum += 42 - } - ip.SetChecksum(^ipHeaderChecksum) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) - reply, ok := e.Read() - if !ok { - if test.shouldFail { - if test.expectErrorICMP { - t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) - } - return // Expected silent failure. - } - t.Fatal("expected ICMP echo reply missing") - } - - // We didn't expect a packet. Register our surprise but carry on to - // provide more information about what we got. - if test.shouldFail && !test.expectErrorICMP { - t.Error("unexpected packet response") - } - - // Check the route that brought the packet to us. - if reply.Route.LocalAddress != ipv4Addr.Address { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) - } - if reply.Route.RemoteAddress != remoteIPv4Addr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) - } - - // Make sure it's all in one buffer for checker. - replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) - - // At this stage we only know it's probably an IP+ICMP header so verify - // that much. - checker.IPv4(t, replyIPHeader, - checker.SrcAddr(ipv4Addr.Address), - checker.DstAddr(remoteIPv4Addr), - checker.ICMPv4( - checker.ICMPv4Checksum(), - ), - ) - - // Don't proceed any further if the checker found problems. - if t.Failed() { - t.FailNow() - } - - // OK it's ICMP. We can safely look at the type now. - replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) - switch replyICMPHeader.Type() { - case header.ICMPv4ParamProblem: - if !test.shouldFail { - t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) - } - if !test.expectErrorICMP { - t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) - } - checker.IPv4(t, replyIPHeader, - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(test.ICMPType), - checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Pointer(test.paramProblemPointer), - checker.ICMPv4Payload([]byte(hdr.View())), - ), - ) - return - case header.ICMPv4DstUnreachable: - if !test.shouldFail { - t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", - header.ICMPv4DstUnreachable, replyICMPHeader.Code()) - } - if !test.expectErrorICMP { - t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", - header.ICMPv4DstUnreachable, replyICMPHeader.Code()) - } - checker.IPv4(t, replyIPHeader, - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(test.ICMPType), - checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Payload([]byte(hdr.View())), - ), - ) - return - case header.ICMPv4EchoReply: - if test.shouldFail { - if !test.expectErrorICMP { - t.Error("got Echo Reply packet, want no response") - } else { - t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) - } - } - // If the IP options change size then the packet will change size, so - // some IP header fields will need to be adjusted for the checks. - sizeChange := len(test.replyOptions) - len(test.options) - - checker.IPv4(t, replyIPHeader, - checker.IPv4HeaderLength(ipHeaderLength+sizeChange), - checker.IPv4Options(test.replyOptions), - checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Seq(randomSequence), - checker.ICMPv4Ident(randomIdent), - ), - ) - default: - t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", - replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) - } - }) - } -} - -// comparePayloads compared the contents of all the packets against the contents -// of the source packet. -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { - // Make a complete array of the sourcePacket packet. - source := header.IPv4(packets[0].NetworkHeader().View()) - vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) - source = append(source, vv.ToView()...) - - // Make a copy of the IP header, which will be modified in some fields to make - // an expected header. - sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) - sourceCopy.SetChecksum(0) - sourceCopy.SetFlagsFragmentOffset(0, 0) - sourceCopy.SetTotalLength(0) - // Build up an array of the bytes sent. - var reassembledPayload buffer.VectorisedView - for i, packet := range packets { - // Confirm that the packet is valid. - allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) - fragmentIPHeader := header.IPv4(allBytes.ToView()) - if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { - return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) - } - if got := len(fragmentIPHeader); got > int(mtu) { - return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) - } - if got := fragmentIPHeader.TransportProtocol(); got != proto { - return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) - } - if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve { - return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) - } - if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { - return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) - } - if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { - return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) - } - if wantFragments[i].more { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset) - } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) - } - reassembledPayload.AppendView(packet.TransportHeader().View()) - reassembledPayload.AppendView(packet.Data().AsRange().ToOwnedView()) - // Clear out the checksum and length from the ip because we can't compare - // it. - sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) - sourceCopy.SetChecksum(0) - sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { - return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) - } - } - - expected := buffer.View(source[source.HeaderLength():]) - if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { - return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - - return nil -} - -type fragmentInfo struct { - offset uint16 - more bool - payloadSize uint16 -} - -var fragmentationTests = []struct { - description string - mtu uint32 - gso *stack.GSO - transportHeaderLength int - payloadSize int - wantFragments []fragmentInfo -}{ - { - description: "No fragmentation", - mtu: 1280, - gso: nil, - transportHeaderLength: 0, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1000, more: false}, - }, - }, - { - description: "Fragmented", - mtu: 1280, - gso: nil, - transportHeaderLength: 0, - payloadSize: 2000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 744, more: false}, - }, - }, - { - description: "Fragmented with the minimum mtu", - mtu: header.IPv4MinimumMTU, - gso: nil, - transportHeaderLength: 0, - payloadSize: 100, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 48, more: true}, - {offset: 48, payloadSize: 48, more: true}, - {offset: 96, payloadSize: 4, more: false}, - }, - }, - { - description: "Fragmented with mtu not a multiple of 8", - mtu: header.IPv4MinimumMTU + 1, - gso: nil, - transportHeaderLength: 0, - payloadSize: 100, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 48, more: true}, - {offset: 48, payloadSize: 48, more: true}, - {offset: 96, payloadSize: 4, more: false}, - }, - }, - { - description: "No fragmentation with big header", - mtu: 2000, - gso: nil, - transportHeaderLength: 100, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1100, more: false}, - }, - }, - { - description: "Fragmented with gso none", - mtu: 1280, - gso: &stack.GSO{Type: stack.GSONone}, - transportHeaderLength: 0, - payloadSize: 1400, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 144, more: false}, - }, - }, - { - description: "Fragmented with big header", - mtu: 1280, - gso: nil, - transportHeaderLength: 100, - payloadSize: 1200, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 44, more: false}, - }, - }, - { - description: "Fragmented with MTU smaller than header", - mtu: 300, - gso: nil, - transportHeaderLength: 1000, - payloadSize: 500, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 280, more: true}, - {offset: 280, payloadSize: 280, more: true}, - {offset: 560, payloadSize: 280, more: true}, - {offset: 840, payloadSize: 280, more: true}, - {offset: 1120, payloadSize: 280, more: true}, - {offset: 1400, payloadSize: 100, more: false}, - }, - }, -} - -func TestFragmentationWritePacket(t *testing.T) { - const ttl = 42 - - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - source := pkt.Clone() - err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if err != nil { - t.Fatalf("r.WritePacket(_, _, _) = %s", err) - } - if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) - } - if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } - }) - } -} - -func TestFragmentationWritePackets(t *testing.T) { - const ttl = 42 - writePacketsTests := []struct { - description string - insertBefore int - insertAfter int - }{ - { - description: "Single packet", - insertBefore: 0, - insertAfter: 0, - }, - { - description: "With packet before", - insertBefore: 1, - insertAfter: 0, - }, - { - description: "With packet after", - insertBefore: 0, - insertAfter: 1, - }, - { - description: "With packet before and after", - insertBefore: 1, - insertAfter: 1, - }, - } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) - - for _, test := range writePacketsTests { - t.Run(test.description, func(t *testing.T) { - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - var pkts stack.PacketBufferList - for i := 0; i < test.insertBefore; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - pkts.PushBack(pkt.Clone()) - for i := 0; i < test.insertAfter; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - - wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }) - if err != nil { - t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err) - } - if n != wantTotalPackets { - t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets) - } - if got := len(ep.WrittenPackets); got != wantTotalPackets { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - - if wantTotalPackets == 0 { - return - } - - fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] - if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } - }) - } - }) - } -} - -// TestFragmentationErrors checks that errors are returned from WritePacket -// correctly. -func TestFragmentationErrors(t *testing.T) { - const ttl = 42 - - tests := []struct { - description string - mtu uint32 - transportHeaderLength int - payloadSize int - allowPackets int - outgoingErrors int - mockError tcpip.Error - wantError tcpip.Error - }{ - { - description: "No frag", - mtu: 2000, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 0, - outgoingErrors: 1, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag", - mtu: 500, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 0, - outgoingErrors: 3, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on second frag", - mtu: 500, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 1, - outgoingErrors: 2, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag MTU smaller than header", - mtu: 500, - transportHeaderLength: 1000, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 4, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error when MTU is smaller than IPv4 minimum MTU", - mtu: header.IPv4MinimumMTU - 1, - transportHeaderLength: 0, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 1, - mockError: nil, - wantError: &tcpip.ErrInvalidEndpointState{}, - }, - } - - for _, ft := range tests { - t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) - r := buildRoute(t, ep) - err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if diff := cmp.Diff(ft.wantError, err); diff != "" { - t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { - t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) - } - }) - } -} - -func TestInvalidFragments(t *testing.T) { - const ( - nicID = 1 - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" - tos = 0 - ident = 1 - ttl = 48 - protocol = 6 - ) - - payloadGen := func(payloadLen int) []byte { - payload := make([]byte, payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = 0x30 - } - return payload - } - - type fragmentData struct { - ipv4fields header.IPv4Fields - // 0 means insert the correct IHL. Non 0 means override the correct IHL. - overrideIHL int // For 0 use 1 as it is an int and will be divided by 4. - payload []byte - autoChecksum bool // If true, the Checksum field will be overwritten. - } - - tests := []struct { - name string - fragments []fragmentData - wantMalformedIPPackets uint64 - wantMalformedFragments uint64 - }{ - { - name: "IHL and TotalLength zero, FragmentOffset non-zero", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: 0, - ID: ident, - Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments, - FragmentOffset: 59776, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - overrideIHL: 1, // See note above. - payload: payloadGen(12), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 0, - }, - { - name: "IHL and TotalLength zero, FragmentOffset zero", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: 0, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - overrideIHL: 1, // See note above. - payload: payloadGen(12), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 0, - }, - { - // Payload 17 octets and Fragment offset 65520 - // Leading to the fragment end to be past 65536. - name: "fragment ends past 65536", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 17, - ID: ident, - Flags: 0, - FragmentOffset: 65520, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(17), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - { - // Payload 16 octets and fragment offset 65520 - // Leading to the fragment end to be exactly 65536. - name: "fragment ends exactly at 65536", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: 0, - FragmentOffset: 65520, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(16), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 0, - wantMalformedFragments: 0, - }, - { - name: "IHL less than IPv4 minimum size", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 28, - ID: ident, - Flags: 0, - FragmentOffset: 1944, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize - 12, - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize - 12, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize - 12, - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 2, - wantMalformedFragments: 0, - }, - { - name: "fragment with short TotalLength and extra payload", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 28, - ID: ident, - Flags: 0, - FragmentOffset: 28816, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize + 4, - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 4, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize + 4, - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - { - name: "multiple fragments with More Fragments flag set to false", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: 0, - FragmentOffset: 128, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: 0, - FragmentOffset: 8, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - }, - }) - e := channel.New(0, 1500, linkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) - } - - for _, f := range test.fragments { - pktSize := header.IPv4MinimumSize + len(f.payload) - hdr := buffer.NewPrependable(pktSize) - - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&f.ipv4fields) - if want, got := len(f.payload), copy(ip[header.IPv4MinimumSize:], f.payload); want != got { - t.Fatalf("copied %d bytes, expected %d bytes.", got, want) - } - // Encode sets this up correctly. If we want a different value for - // testing then we need to overwrite the good value. - if f.overrideIHL != 0 { - ip.SetHeaderLength(uint8(f.overrideIHL)) - // If we are asked to add options (type not specified) then pad - // with 0 (EOL). RFC 791 page 23 says "The padding is zero". - for i := header.IPv4MinimumSize; i < f.overrideIHL; i++ { - ip[i] = byte(header.IPv4OptionListEndType) - } - } - - if f.autoChecksum { - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - } - - vv := hdr.View().ToVectorisedView() - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { - t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) - } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { - t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) - } - }) - } -} - -func TestFragmentReassemblyTimeout(t *testing.T) { - const ( - nicID = 1 - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = "\x0a\x00\x00\x01" - addr2 = "\x0a\x00\x00\x02" - tos = 0 - ident = 1 - ttl = 48 - protocol = 99 - data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" - ) - - type fragmentData struct { - ipv4fields header.IPv4Fields - payload []byte - } - - tests := []struct { - name string - fragments []fragmentData - expectICMP bool - }{ - { - name: "first fragment only", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "two first fragments", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "second fragment only", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 8, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: false, - }, - { - name: "two fragments with a gap", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:8], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 16, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: true, - }, - { - name: "two fragments with a gap in reverse order", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 16, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:8], - }, - }, - expectICMP: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - }, - Clock: clock, - }) - e := channel.New(1, 1500, linkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }}) - - var firstFragmentSent buffer.View - for _, f := range test.fragments { - pktSize := header.IPv4MinimumSize - hdr := buffer.NewPrependable(pktSize) - - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&f.ipv4fields) - - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(f.payload) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - - if firstFragmentSent == nil && ip.FragmentOffset() == 0 { - firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) - } - - e.InjectInbound(header.IPv4ProtocolNumber, pkt) - } - - clock.Advance(ipv4.ReassembleTimeout) - - reply, ok := e.Read() - if !test.expectICMP { - if ok { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - return - } - if !ok { - t.Fatal("expected ICMP error message missing") - } - if firstFragmentSent == nil { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - - checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(addr2), - checker.DstAddr(addr1), - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4TimeExceeded), - checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), - checker.ICMPv4Checksum(), - checker.ICMPv4Payload([]byte(firstFragmentSent)), - ), - ) - }) - } -} - -// TestReceiveFragments feeds fragments in through the incoming packet path to -// test reassembly -func TestReceiveFragments(t *testing.T) { - const ( - nicID = 1 - - addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 - addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 - addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 - ) - - // Build and return a UDP header containing payload. - udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View { - payload := buffer.NewView(payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) * multiplier - } - - udpLength := header.UDPMinimumSize + len(payload) - - hdr := buffer.NewPrependable(udpLength) - u := header.UDP(hdr.Prepend(udpLength)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: uint16(udpLength), - }) - copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) - sum = header.Checksum(payload, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - return hdr.View() - } - - // UDP header plus a payload of 0..256 - ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2) - udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:] - ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2) - udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:] - // UDP header plus a payload of 0..256 in increments of 2. - ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2) - udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:] - // UDP header plus a payload of 0..256 in increments of 3. - // Used to test cases where the fragment blocks are not a multiple of - // the fragment block size of 8 (RFC 791 section 3.1 page 14). - ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) - udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] - // Used to test the max reassembled IPv4 payload length. - ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) - udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] - - type fragmentData struct { - srcAddr tcpip.Address - dstAddr tcpip.Address - id uint16 - flags uint8 - fragmentOffset uint16 - payload buffer.View - } - - tests := []struct { - name string - fragments []fragmentData - expectedPayloads [][]byte - }{ - { - name: "No fragmentation", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "No fragmentation with size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2, - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "More fragments without payload", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: nil, - }, - { - name: "Non-zero fragment offset without payload", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 8, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments out of order", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with last fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload3Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "Two fragments with first fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2[:63], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 63, - payload: ipv4Payload3Addr1ToAddr2[63:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Second fragment has MoreFlags set", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with different IDs", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two interleaved fragmented packets", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload2Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload2Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, - }, - { - name: "Two interleaved fragmented packets from different sources but with same ID", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr3, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr3ToAddr2[:32], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr3, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 32, - payload: ipv4Payload1Addr3ToAddr2[32:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, - }, - { - name: "Fragment without followup", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload4Addr1ToAddr2[:65512], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 65512, - payload: ipv4Payload4Addr1ToAddr2[65512:], - }, - }, - expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, - }, - { - name: "Two fragments with MF flag reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload4Addr1ToAddr2[:65512], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 65512, - payload: ipv4Payload4Addr1ToAddr2[65512:], - }, - }, - expectedPayloads: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Setup a stack and endpoint. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - RawFactory: raw.EndpointFactory{}, - }) - e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) - } - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - // Bring up a raw endpoint so we can examine network headers. - epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */) - if err != nil { - t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) - } - defer epRaw.Close() - - // Prepare and send the fragments. - for _, frag := range test.fragments { - hdr := buffer.NewPrependable(header.IPv4MinimumSize) - - // Serialize IPv4 fixed header. - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)), - ID: frag.id, - Flags: frag.flags, - FragmentOffset: frag.fragmentOffset, - TTL: 64, - Protocol: uint8(header.UDPProtocolNumber), - SrcAddr: frag.srcAddr, - DstAddr: frag.dstAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(frag.payload) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { - t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) - } - - for i, expectedPayload := range test.expectedPayloads { - // Check UDP payload delivered by UDP endpoint. - var buf bytes.Buffer - result, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) ep.Read: %s", i, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(expectedPayload), - Total: len(expectedPayload), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff) - } - if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" { - t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff) - } - - // Check IPv4 header in packet delivered by raw endpoint. - buf.Reset() - result, err = epRaw.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) epRaw.Read: %s", i, err) - } - // Reassambly does not take care of checksum. Here we write our own - // check routine instead of using checker.IPv4. - ip := header.IPv4(buf.Bytes()) - for _, check := range []checker.NetworkChecker{ - checker.FragmentFlags(0), - checker.FragmentOffset(0), - checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))), - } { - check(t, []header.Network{ip}) - } - } - - res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) - } - }) - } -} - -func TestWriteStats(t *testing.T) { - const nPackets = 3 - - tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectDropped int - expectWritten int - }{ - { - name: "Accept all", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectDropped: 0, - expectWritten: nPackets, - }, { - name: "Accept all with error", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectDropped: 0, - expectWritten: nPackets - 1, - }, { - name: "Drop all", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule. - t.Helper() - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectDropped: nPackets, - expectWritten: nPackets, - }, { - name: "Drop some", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule that matches only 1 - // of the 3 packets. - t.Helper() - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - // We'll match and DROP the last packet. - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} - // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectDropped: 1, - expectWritten: nPackets, - }, - } - - // Parameterize the tests to run with both WritePacket and WritePackets. - writers := []struct { - name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) - }{ - { - name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - nWritten := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { - return nWritten, err - } - nWritten++ - } - return nWritten, nil - }, - }, { - name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) - }, - }, - } - - for _, writer := range writers { - t.Run(writer.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) - rt := buildRoute(t, ep) - - var pkts stack.PacketBufferList - for i := 0; i < nPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), - Data: buffer.NewView(0).ToVectorisedView(), - }) - pkt.TransportHeader().Push(header.UDPMinimumSize) - pkts.PushBack(pkt) - } - - test.setup(t, rt.Stack()) - - nWritten, _ := writer.writePackets(rt, pkts) - - if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) - } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { - t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) - } - if nWritten != test.expectWritten { - t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) - } - }) - } - }) - } -} - -func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC(1, _) failed: %s", err) - } - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) - } - { - mask := tcpip.AddressMask(header.IPv4Broadcast) - subnet, err := tcpip.NewSubnet(dst, mask) - if err != nil { - t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err) - } - return rt -} - -// limitedMatcher is an iptables matcher that matches after a certain number of -// packets are checked against it. -type limitedMatcher struct { - limit int -} - -// Name implements Matcher.Name. -func (*limitedMatcher) Name() string { - return "limitedMatcher" -} - -// Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { - if lm.limit == 0 { - return true, false - } - lm.limit-- - return false, false -} - -func TestPacketQueing(t *testing.T) { - const nicID = 1 - - var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - - host1IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), - PrefixLen: 24, - }, - } - host2IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), - PrefixLen: 8, - }, - } - ) - - tests := []struct { - name string - rxPkt func(*channel.Endpoint) - checkResp func(*testing.T, *channel.Endpoint) - }{ - { - name: "ICMP Error", - rxPkt: func(e *channel.Endpoint) { - hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, - TTL: ipv4.DefaultTTL, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, - DstAddr: host1IPv4Addr.AddressWithPrefix.Address, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != header.IPv4ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - }, - }, - - { - name: "Ping", - rxPkt: func(e *channel.Endpoint) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ipv4.DefaultTTL, - SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, - DstAddr: host1IPv4Addr.AddressWithPrefix.Address, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != header.IPv4ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4EchoReply), - checker.ICMPv4Code(header.ICMPv4UnusedCode))) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(1, defaultMTU, host1NICLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: nicID, - }, - }) - - // Receive a packet to trigger link resolution before a response is sent. - test.rxPkt(e) - - // Wait for a ARP request since link address resolution should be - // performed. - { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != arp.ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) - } - if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) - } - rep := header.ARP(p.Pkt.NetworkHeader().View()) - if got := rep.Op(); got != header.ARPRequest { - t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr) - } - if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address) - } - if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address) - } - } - - // Send an ARP reply to complete link address resolution. - { - hdr := buffer.View(make([]byte, header.ARPSize)) - packet := header.ARP(hdr) - packet.SetIPv4OverEthernet() - packet.SetOp(header.ARPReply) - copy(packet.HardwareAddressSender(), host2NICLinkAddr) - copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address) - copy(packet.HardwareAddressTarget(), host1NICLinkAddr) - copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address) - e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.ToVectorisedView(), - })) - } - - // Expect the response now that the link address has resolved. - test.checkResp(t, e) - - // Since link resolution was already performed, it shouldn't be performed - // again. - test.rxPkt(e) - test.checkResp(t, e) - }) - } -} - -// TestCloseLocking test that lock ordering is followed when closing an -// endpoint. -func TestCloseLocking(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - src = tcpip.Address("\x10\x00\x00\x01") - dst = tcpip.Address("\x10\x00\x00\x02") - - iterations = 1000 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - // Perform NAT so that the endoint tries to search for a sibling endpoint - // which ends up taking the protocol and endpoint lock (in that order). - table := stack.Table{ - Rules: []stack.Rule{ - {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 1, - stack.Forward: stack.HookUnset, - stack.Output: 2, - stack.Postrouting: 3, - }, - Underflows: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 1, - stack.Forward: stack.HookUnset, - stack.Output: 2, - stack.Postrouting: 3, - }, - } - if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil { - t.Fatalf("s.IPTables().ReplaceTable(...): %s", err) - } - - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - - if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: nicID1, - }}) - - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53} - if err := ep.Connect(addr); err != nil { - t.Errorf("ep.Connect(%#v): %s", addr, err) - } - - var wg sync.WaitGroup - defer wg.Wait() - - // Writing packets should trigger NAT which requires the stack to search the - // protocol for network endpoints with the destination address. - // - // Creating and removing interfaces should modify the protocol and endpoint - // which requires taking the locks of each. - // - // We expect the protocol > endpoint lock ordering to be followed here. - wg.Add(2) - go func() { - defer wg.Done() - - data := []byte{1, 2, 3, 4} - - for i := 0; i < iterations; i++ { - var r bytes.Reader - r.Reset(data) - if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Errorf("ep.Write(_, _): %s", err) - return - } else if want := int64(len(data)); n != want { - t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) - return - } - } - }() - go func() { - defer wg.Done() - - for i := 0; i < iterations; i++ { - if err := s.CreateNIC(nicID2, loopback.New()); err != nil { - t.Errorf("CreateNIC(%d, _): %s", nicID2, err) - return - } - if err := s.RemoveNIC(nicID2); err != nil { - t.Errorf("RemoveNIC(%d): %s", nicID2, err) - return - } - } - }() -} diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go deleted file mode 100644 index a637f9d50..000000000 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4 - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.NetworkInterface - nicID tcpip.NICID -} - -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEP := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEP { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - // At this point, the Stack's stats and the NetworkEndpoint's stats are - // expected to be bound by a MultiCounterStat. - refStack := s.Stats() - refEP := ep.stats.localStats - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD deleted file mode 100644 index bb9a02ed0..000000000 --- a/pkg/tcpip/network/ipv6/BUILD +++ /dev/null @@ -1,70 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv6", - srcs = [ - "dhcpv6configurationfromndpra_string.go", - "icmp.go", - "ipv6.go", - "mld.go", - "ndp.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/hash", - "//pkg/tcpip/network/internal/fragmentation", - "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv6_test", - size = "small", - srcs = [ - "icmp_test.go", - "ipv6_test.go", - "ndp_test.go", - ], - library = ":ipv6", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "ipv6_x_test", - size = "small", - srcs = ["mld_test.go"], - deps = [ - ":ipv6", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go deleted file mode 100644 index d4e63710c..000000000 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ /dev/null @@ -1,1712 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "bytes" - "context" - "net" - "reflect" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - nicID = 1 - - linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") - - defaultChannelSize = 1 - defaultMTU = 65536 - - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 30 * time.Second - - arbitraryHopLimit = 42 -) - -var ( - lladdr0 = header.LinkLocalAddr(linkAddr0) - lladdr1 = header.LinkLocalAddr(linkAddr1) - lladdr2 = header.LinkLocalAddr(linkAddr2) -) - -type stubLinkEndpoint struct { - stack.LinkEndpoint -} - -func (*stubLinkEndpoint) MTU() uint32 { - return defaultMTU -} - -func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - // Indicate that resolution for link layer addresses is required to send - // packets over this link. This is needed so the NIC knows to allocate a - // neighbor table. - return stack.CapabilityResolutionRequired -} - -func (*stubLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - return nil -} - -func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -type stubDispatcher struct { - stack.TransportDispatcher -} - -func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { - return stack.TransportPacketHandled -} - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.LinkEndpoint - - probeCount int - confirmationCount int - - nicID tcpip.NICID -} - -func (*testInterface) ID() tcpip.NICID { - return nicID -} - -func (*testInterface) IsLoopback() bool { - return false -} - -func (*testInterface) Name() string { - return "" -} - -func (*testInterface) Enabled() bool { - return true -} - -func (*testInterface) Promiscuous() bool { - return false -} - -func (*testInterface) Spoofing() bool { - return false -} - -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) -} - -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) -} - -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - var r stack.RouteInfo - r.NetProto = protocol - r.RemoteLinkAddress = remoteLinkAddr - return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) -} - -func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { - t.probeCount++ - return nil -} - -func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { - t.confirmationCount++ - return nil -} - -func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { - return false -} - -func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6, hopLimit uint8, includeRouterAlert bool) { - var extensionHeaders header.IPv6ExtHdrSerializer - if includeRouterAlert { - extensionHeaders = header.IPv6ExtHdrSerializer{ - header.IPv6SerializableHopByHopExtHdr{ - &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, - }, - } - } - ip := buffer.NewView(header.IPv6MinimumSize + extensionHeaders.Length()) - header.IPv6(ip).Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: hopLimit, - SrcAddr: src, - DstAddr: dst, - ExtensionHeaders: extensionHeaders, - }) - - vv := ip.ToVectorisedView() - vv.AppendView(buffer.View(icmp)) - ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) -} - -func TestICMPCounts(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - typ header.ICMPv6Type - hopLimit uint8 - includeRouterAlert bool - size int - extraData []byte - }{ - { - typ: header.ICMPv6DstUnreachable, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6DstUnreachableMinimumSize, - }, - { - typ: header.ICMPv6PacketTooBig, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6PacketTooBigMinimumSize, - }, - { - typ: header.ICMPv6TimeExceeded, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6ParamProblem, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6EchoRequest, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6EchoReply, - hopLimit: arbitraryHopLimit, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6RouterSolicit, - hopLimit: header.NDPHopLimit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6RouterAdvert, - hopLimit: header.NDPHopLimit, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - }, - { - typ: header.ICMPv6NeighborSolicit, - hopLimit: header.NDPHopLimit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, - { - typ: header.ICMPv6NeighborAdvert, - hopLimit: header.NDPHopLimit, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - }, - { - typ: header.ICMPv6RedirectMsg, - hopLimit: header.NDPHopLimit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6MulticastListenerQuery, - hopLimit: header.MLDHopLimit, - includeRouterAlert: true, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerReport, - hopLimit: header.MLDHopLimit, - includeRouterAlert: true, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerDone, - hopLimit: header.MLDHopLimit, - includeRouterAlert: true, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: 255, /* Unrecognized */ - size: 50, - }, - } - - for _, typ := range types { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp[:typ.size], - Src: lladdr0, - Dst: lladdr1, - PayloadCsum: header.Checksum(typ.extraData, 0 /* initial */), - PayloadLen: len(typ.extraData), - })) - handleICMPInIPv6(ep, lladdr1, lladdr0, icmp, typ.hopLimit, typ.includeRouterAlert) - } - - // Construct an empty ICMP packet so that - // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)), arbitraryHopLimit, false) - - icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived - visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { - if got, want := s.Value(), uint64(1); got != want { - t.Errorf("got %s = %d, want = %d", name, got, want) - } - }) - if t.Failed() { - t.Logf("stats:\n%+v", s.Stats()) - } -} - -func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { - t := v.Type() - for i := 0; i < v.NumField(); i++ { - v := v.Field(i) - if s, ok := v.Interface().(*tcpip.StatCounter); ok { - f(t.Field(i).Name, s) - } else { - visitStats(v, f) - } - } -} - -type testContext struct { - s0 *stack.Stack - s1 *stack.Stack - - linkEP0 *channel.Endpoint - linkEP1 *channel.Endpoint -} - -type endpointWithResolutionCapability struct { - stack.LinkEndpoint -} - -func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities { - return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired -} - -func newTestContext(t *testing.T) *testContext { - c := &testContext{ - s0: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }), - s1: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }), - } - - c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) - - wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) - if testing.Verbose() { - wrappedEP0 = sniffer.New(wrappedEP0) - } - if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil { - t.Fatalf("CreateNIC s0: %v", err) - } - if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress lladdr0: %v", err) - } - - c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) - wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) - if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil { - t.Fatalf("AddAddress lladdr1: %v", err) - } - - subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - c.s0.SetRouteTable( - []tcpip.Route{{ - Destination: subnet0, - NIC: nicID, - }}, - ) - subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) - if err != nil { - t.Fatal(err) - } - c.s1.SetRouteTable( - []tcpip.Route{{ - Destination: subnet1, - NIC: nicID, - }}, - ) - - t.Cleanup(func() { - if err := c.s0.RemoveNIC(nicID); err != nil { - t.Errorf("c.s0.RemoveNIC(%d): %s", nicID, err) - } - if err := c.s1.RemoveNIC(nicID); err != nil { - t.Errorf("c.s1.RemoveNIC(%d): %s", nicID, err) - } - - c.linkEP0.Close() - c.linkEP1.Close() - }) - - return c -} - -type routeArgs struct { - src, dst *channel.Endpoint - typ header.ICMPv6Type - remoteLinkAddr tcpip.LinkAddress -} - -func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { - t.Helper() - - pi, _ := args.src.ReadContext(context.Background()) - - { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()), - }) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt) - } - - if pi.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pi.Proto) - return - } - - if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr { - t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) - } - - // Pull the full payload since network header. Needed for header.IPv6 to - // extract its payload. - ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader())) - transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) - if transProto != header.ICMPv6ProtocolNumber { - t.Errorf("unexpected transport protocol number %d", transProto) - return - } - icmpv6 := header.ICMPv6(ipv6.Payload()) - if got, want := icmpv6.Type(), args.typ; got != want { - t.Errorf("got ICMPv6 type = %d, want = %d", got, want) - return - } - if fn != nil { - fn(t, icmpv6) - } -} - -func TestLinkResolution(t *testing.T) { - c := newTestContext(t) - - r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) - } - defer r.Release() - - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: r.LocalAddress, - Dst: r.RemoteAddress, - })) - - // We can't send our payload directly over the route because that - // doesn't provoke NDP discovery. - var wq waiter.Queue - ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) - } - - { - var r bytes.Reader - r.Reset(hdr.View()) - if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { - t.Fatalf("ep.Write(_): %s", err) - } - } - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, - } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { - if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { - t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) - } - }) - } - - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, - } { - routeICMPv6Packet(t, args, nil) - } -} - -func TestICMPChecksumValidationSimple(t *testing.T) { - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - routerOnly bool - }{ - { - name: "DstUnreachable", - typ: header.ICMPv6DstUnreachable, - size: header.ICMPv6DstUnreachableMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - }, - { - name: "PacketTooBig", - typ: header.ICMPv6PacketTooBig, - size: header.ICMPv6PacketTooBigMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - }, - { - name: "TimeExceeded", - typ: header.ICMPv6TimeExceeded, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - }, - { - name: "ParamProblem", - typ: header.ICMPv6ParamProblem, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - }, - { - name: "EchoRequest", - typ: header.ICMPv6EchoRequest, - size: header.ICMPv6EchoMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - }, - { - name: "EchoReply", - typ: header.ICMPv6EchoReply, - size: header.ICMPv6EchoMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - }, - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - // Hosts MUST silently discard any received Router Solicitation messages. - routerOnly: true, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } - - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" - } - t.Run(name, func(t *testing.T) { - e := channel.New(0, 1280, linkAddr0) - - // Indicate that resolution for link layer addresses is required to - // send packets over this link. This is needed so the NIC knows to - // allocate a neighbor table. - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) - } - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(checksum bool) { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - if checksum { - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: lladdr1, - Dst: lladdr0, - })) - } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Router only count should not have increased. - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - if !isRouter && typ.routerOnly { - // Router only count should have increased. - if got := routerOnly.Value(); got != 1 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) - } - } - }) - } - } -} - -func TestICMPChecksumValidationWithPayload(t *testing.T) { - const simpleBodySize = 64 - simpleBody := func(view buffer.View) { - for i := 0; i < simpleBodySize; i++ { - view[i] = uint8(i) - } - } - - const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize - errorICMPBody := func(view buffer.View) { - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - TransportProtocol: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, - }) - simpleBody(view[header.IPv6MinimumSize:]) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - payloadSize int - payload func(buffer.View) - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - simpleBodySize, - simpleBody, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - simpleBodySize, - simpleBody, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { - icmpSize := size + payloadSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize)) - icmpHdr.SetType(typ) - payloadFn(icmpHdr.Payload()) - - if checksum { - icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmpHdr, - Src: lladdr1, - Dst: lladdr0, - })) - } - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got = %d, want = 0", got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { - const simpleBodySize = 64 - simpleBody := func(view buffer.View) { - for i := 0; i < simpleBodySize; i++ { - view[i] = uint8(i) - } - } - - const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize - errorICMPBody := func(view buffer.View) { - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - TransportProtocol: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, - }) - simpleBody(view[header.IPv6MinimumSize:]) - } - - types := []struct { - name string - typ header.ICMPv6Type - size int - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - payloadSize int - payload func(buffer.View) - }{ - { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.DstUnreachable - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.PacketTooBig - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.TimeExceeded - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.ParamProblem - }, - errorICMPBodySize, - errorICMPBody, - }, - { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoRequest - }, - simpleBodySize, - simpleBody, - }, - { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.EchoReply - }, - simpleBodySize, - simpleBody, - }, - } - - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - icmpHdr := header.ICMPv6(hdr.Prepend(size)) - icmpHdr.SetType(typ) - - payload := buffer.NewView(payloadSize) - payloadFn(payload) - - if checksum { - icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmpHdr, - Src: lladdr1, - Dst: lladdr0, - PayloadCsum: header.Checksum(payload, 0 /* initial */), - PayloadLen: len(payload), - })) - } - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size + payloadSize), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got = %d, want = 0", got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got = %d, want = 0", got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - }) - } -} - -func TestLinkAddressRequest(t *testing.T) { - const nicID = 1 - - snaddr := header.SolicitedNodeAddr(lladdr0) - mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) - - tests := []struct { - name string - nicAddr tcpip.Address - localAddr tcpip.Address - remoteLinkAddr tcpip.LinkAddress - - expectedErr tcpip.Error - expectedRemoteAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress - }{ - { - name: "Unicast", - nicAddr: lladdr1, - localAddr: lladdr1, - remoteLinkAddr: linkAddr1, - expectedRemoteAddr: lladdr0, - expectedRemoteLinkAddr: linkAddr1, - }, - { - name: "Multicast", - nicAddr: lladdr1, - localAddr: lladdr1, - remoteLinkAddr: "", - expectedRemoteAddr: snaddr, - expectedRemoteLinkAddr: mcaddr, - }, - { - name: "Unicast with unspecified source", - nicAddr: lladdr1, - remoteLinkAddr: linkAddr1, - expectedRemoteAddr: lladdr0, - expectedRemoteLinkAddr: linkAddr1, - }, - { - name: "Multicast with unspecified source", - nicAddr: lladdr1, - remoteLinkAddr: "", - expectedRemoteAddr: snaddr, - expectedRemoteLinkAddr: mcaddr, - }, - { - name: "Unicast with unassigned address", - localAddr: lladdr1, - remoteLinkAddr: linkAddr1, - expectedErr: &tcpip.ErrBadLocalAddress{}, - }, - { - name: "Multicast with unassigned address", - localAddr: lladdr1, - remoteLinkAddr: "", - expectedErr: &tcpip.ErrBadLocalAddress{}, - }, - { - name: "Unicast with no local address available", - remoteLinkAddr: linkAddr1, - expectedErr: &tcpip.ErrNetworkUnreachable{}, - }, - { - name: "Multicast with no local address available", - remoteLinkAddr: "", - expectedErr: &tcpip.ErrNetworkUnreachable{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - - linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := s.CreateNIC(nicID, linkEP); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) - } - linkRes, ok := ep.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) - } - - if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) - } - } - - { - err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff) - } - } - - if test.expectedErr != nil { - return - } - - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } - - var want stack.RouteInfo - want.NetProto = ProtocolNumber - want.RemoteLinkAddress = test.expectedRemoteLinkAddr - if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" { - t.Errorf("route info mismatch (-want +got):\n%s", diff) - } - checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), - checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedRemoteAddr), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), - )) - }) - } -} - -func TestPacketQueing(t *testing.T) { - const nicID = 1 - - var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - - host1IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::1").To16()), - PrefixLen: 64, - }, - } - host2IPv6Addr = tcpip.ProtocolAddress{ - Protocol: ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("a::2").To16()), - PrefixLen: 64, - }, - } - ) - - tests := []struct { - name string - rxPkt func(*channel.Endpoint) - checkResp func(*testing.T, *channel.Endpoint) - }{ - { - name: "ICMP Error", - rxPkt: func(e *channel.Endpoint) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: udp.ProtocolNumber, - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, - }) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - }, - }, - - { - name: "Ping", - rxPkt: func(e *channel.Endpoint) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: host2IPv6Addr.AddressWithPrefix.Address, - Dst: host1IPv6Addr.AddressWithPrefix.Address, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoReply), - checker.ICMPv6Code(header.ICMPv6UnusedCode))) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - - e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: nicID, - }, - }) - - // Receive a packet to trigger link resolution before a response is sent. - test.rxPkt(e) - - // Wait for a neighbor solicitation since link address resolution should - // be performed. - { - p, ok := e.ReadContext(context.Background()) - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != ProtocolNumber { - t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) - } - snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address), - checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}), - )) - } - - // Send a neighbor advertisement to complete link address resolution. - { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) - pkt := header.ICMPv6(hdr.Prepend(naSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr), - }) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: host2IPv6Addr.AddressWithPrefix.Address, - Dst: host1IPv6Addr.AddressWithPrefix.Address, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: header.NDPHopLimit, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, - }) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - // Expect the response now that the link address has resolved. - test.checkResp(t, e) - - // Since link resolution was already performed, it shouldn't be performed - // again. - test.rxPkt(e) - test.checkResp(t, e) - }) - } -} - -func TestCallsToNeighborCache(t *testing.T) { - tests := []struct { - name string - createPacket func() header.ICMPv6 - multicast bool - source tcpip.Address - destination tcpip.Address - wantProbeCount int - wantConfirmationCount int - }{ - { - name: "Unicast Neighbor Solicitation without source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - return icmp - }, - source: lladdr1, - destination: lladdr0, - // "The source link-layer address option SHOULD be included in unicast - // solicitations." - RFC 4861 section 4.3 - // - // A Neighbor Advertisement needs to be sent in response, but the - // Neighbor Cache shouldn't be updated since we have no useful - // information about the sender. - wantProbeCount: 0, - }, - { - name: "Unicast Neighbor Solicitation with source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - ns.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: lladdr0, - wantProbeCount: 1, - }, - { - name: "Multicast Neighbor Solicitation without source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - return icmp - }, - source: lladdr1, - destination: header.SolicitedNodeAddr(lladdr0), - // "The source link-layer address option MUST be included in multicast - // solicitations." - RFC 4861 section 4.3 - wantProbeCount: 0, - }, - { - name: "Multicast Neighbor Solicitation with source link-layer address option", - createPacket: func() header.ICMPv6 { - nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(nsSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(lladdr0) - ns.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: header.SolicitedNodeAddr(lladdr0), - wantProbeCount: 1, - }, - { - name: "Unicast Neighbor Advertisement without target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - return icmp - }, - source: lladdr1, - destination: lladdr0, - // "When responding to unicast solicitations, the target link-layer - // address option can be omitted since the sender of the solicitation has - // the correct link-layer address; otherwise, it would not be able to - // send the unicast solicitation in the first place." - // - RFC 4861 section 4.4 - wantConfirmationCount: 1, - }, - { - name: "Unicast Neighbor Advertisement with target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: lladdr0, - wantConfirmationCount: 1, - }, - { - name: "Multicast Neighbor Advertisement without target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(false) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - return icmp - }, - source: lladdr1, - destination: header.IPv6AllNodesMulticastAddress, - // "Target link-layer address MUST be included for multicast solicitations - // in order to avoid infinite Neighbor Solicitation "recursion" when the - // peer node does not have a cache entry to return a Neighbor - // Advertisements message." - RFC 4861 section 4.4 - wantConfirmationCount: 0, - }, - { - name: "Multicast Neighbor Advertisement with target link-layer address option", - createPacket: func() header.ICMPv6 { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - icmp := header.ICMPv6(buffer.NewView(naSize)) - icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - na.SetSolicitedFlag(false) - na.SetOverrideFlag(false) - na.SetTargetAddress(lladdr1) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - return icmp - }, - source: lladdr1, - destination: header.IPv6AllNodesMulticastAddress, - wantConfirmationCount: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - testInterface := testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)} - ep := netProto.NewEndpoint(&testInterface, &stubDispatcher{}) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: test.source, - Dst: test.destination, - })) - handleICMPInIPv6(ep, test.source, test.destination, icmp, header.NDPHopLimit, false) - - // Confirm the endpoint calls the correct NUDHandler method. - if testInterface.probeCount != test.wantProbeCount { - t.Errorf("got testInterface.probeCount = %d, want = %d", testInterface.probeCount, test.wantProbeCount) - } - if testInterface.confirmationCount != test.wantConfirmationCount { - t.Errorf("got testInterface.confirmationCount = %d, want = %d", testInterface.confirmationCount, test.wantConfirmationCount) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go new file mode 100644 index 000000000..675fdc220 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go @@ -0,0 +1,126 @@ +// automatically generated by stateify. + +package ipv6 + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *icmpv6DestinationUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationUnreachableSockError" +} + +func (i *icmpv6DestinationUnreachableSockError) StateFields() []string { + return []string{} +} + +func (i *icmpv6DestinationUnreachableSockError) beforeSave() {} + +func (i *icmpv6DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *icmpv6DestinationUnreachableSockError) afterLoad() {} + +func (i *icmpv6DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) { +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationNetworkUnreachableSockError" +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) StateFields() []string { + return []string{ + "icmpv6DestinationUnreachableSockError", + } +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) beforeSave() {} + +func (i *icmpv6DestinationNetworkUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationNetworkUnreachableSockError) afterLoad() {} + +func (i *icmpv6DestinationNetworkUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationPortUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationPortUnreachableSockError" +} + +func (i *icmpv6DestinationPortUnreachableSockError) StateFields() []string { + return []string{ + "icmpv6DestinationUnreachableSockError", + } +} + +func (i *icmpv6DestinationPortUnreachableSockError) beforeSave() {} + +func (i *icmpv6DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationPortUnreachableSockError) afterLoad() {} + +func (i *icmpv6DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationAddressUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6DestinationAddressUnreachableSockError" +} + +func (i *icmpv6DestinationAddressUnreachableSockError) StateFields() []string { + return []string{ + "icmpv6DestinationUnreachableSockError", + } +} + +func (i *icmpv6DestinationAddressUnreachableSockError) beforeSave() {} + +func (i *icmpv6DestinationAddressUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (i *icmpv6DestinationAddressUnreachableSockError) afterLoad() {} + +func (i *icmpv6DestinationAddressUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError) +} + +func (e *icmpv6PacketTooBigSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv6.icmpv6PacketTooBigSockError" +} + +func (e *icmpv6PacketTooBigSockError) StateFields() []string { + return []string{ + "mtu", + } +} + +func (e *icmpv6PacketTooBigSockError) beforeSave() {} + +func (e *icmpv6PacketTooBigSockError) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.mtu) +} + +func (e *icmpv6PacketTooBigSockError) afterLoad() {} + +func (e *icmpv6PacketTooBigSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.mtu) +} + +func init() { + state.Register((*icmpv6DestinationUnreachableSockError)(nil)) + state.Register((*icmpv6DestinationNetworkUnreachableSockError)(nil)) + state.Register((*icmpv6DestinationPortUnreachableSockError)(nil)) + state.Register((*icmpv6DestinationAddressUnreachableSockError)(nil)) + state.Register((*icmpv6PacketTooBigSockError)(nil)) +} diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go deleted file mode 100644 index 266a53e3b..000000000 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ /dev/null @@ -1,3152 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "bytes" - "encoding/hex" - "fmt" - "io/ioutil" - "math" - "net" - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // The least significant 3 bytes are the same as addr2 so both addr2 and - // addr3 will have the same solicited-node address. - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" - addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" - - // Tests use the extension header identifier values as uint8 instead of - // header.IPv6ExtensionHeaderIdentifier. - hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier) - routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier) - fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) - destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) - noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) - unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier) - - extraHeaderReserve = 50 -) - -// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the -// expected Neighbor Advertisement received count after receiving the packet. -func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - // Receive ICMP packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: src, - Dst: dst, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - stats := s.Stats().ICMP.V6.PacketsReceived - - if got := stats.NeighborAdvert.Value(); got != want { - t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) - } -} - -// testReceiveUDP tests receiving a UDP packet from src to dst. want is the -// expected UDP received count after receiving the packet. -func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - - // Receive UDP Packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - - // UDP pseudo-header checksum. - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) - - // UDP checksum - sum = header.Checksum(header.UDP([]byte{}), sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - stat := s.Stats().UDP.PacketsReceived - - if got := stat.Value(); got != want { - t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) - } -} - -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { - // sourcePacket does not have its IP Header populated. Let's copy the one - // from the first fragment. - source := header.IPv6(packets[0].NetworkHeader().View()) - sourceIPHeadersLen := len(source) - vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) - source = append(source, vv.ToView()...) - - var reassembledPayload buffer.VectorisedView - for i, fragment := range packets { - // Confirm that the packet is valid. - allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views()) - fragmentIPHeaders := header.IPv6(allBytes.ToView()) - if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) { - return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders)) - } - - fragmentIPHeadersLength := fragment.NetworkHeader().View().Size() - if fragmentIPHeadersLength != sourceIPHeadersLen { - return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen) - } - - if got := len(fragmentIPHeaders); got > int(mtu) { - return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu) - } - - sourceIPHeader := source[:header.IPv6MinimumSize] - fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize] - - if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize { - return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize) - } - - // We expect the IPv6 Header to be similar across each fragment, besides the - // payload length. - sourceIPHeader.SetPayloadLength(0) - fragmentIPHeader.SetPayloadLength(0) - if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" { - return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) - } - - if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve { - return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) - } - if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber { - return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber) - } - - if len(packets) > 1 { - // If the source packet was big enough that it needed fragmentation, let's - // inspect the fragment header. Because no other extension headers are - // supported, it will always be the last extension header. - fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength]) - - if got := fragmentHeader.More(); got != wantFragments[i].more { - return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more) - } - if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset { - return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset) - } - if got := fragmentHeader.NextHeader(); got != uint8(proto) { - return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto)) - } - } - - // Store the reassembled payload as we parse each fragment. The payload - // includes the Transport header and everything after. - reassembledPayload.AppendView(fragment.TransportHeader().View()) - reassembledPayload.AppendView(fragment.Data().AsRange().ToOwnedView()) - } - - if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" { - return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - - return nil -} - -// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and -// UDP packets destined to the IPv6 link-local all-nodes multicast address. -func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.TransportProtocolFactory - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.NewProtocol6, testReceiveICMP}, - {"UDP", udp.NewProtocol, testReceiveUDP}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, - }) - e := channel.New(10, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Should receive a packet destined to the all-nodes - // multicast address. - test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) - }) - } -} - -// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP -// packets destined to the IPv6 solicited-node address of an assigned IPv6 -// address. -func TestReceiveOnSolicitedNodeAddr(t *testing.T) { - tests := []struct { - name string - protocolFactory stack.TransportProtocolFactory - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) - }{ - {"ICMP", icmp.NewProtocol6, testReceiveICMP}, - {"UDP", udp.NewProtocol, testReceiveUDP}, - } - - snmc := header.SolicitedNodeAddr(addr2) - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, - }) - e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - // Should not receive a packet destined to the solicited node address of - // addr2/addr3 yet as we haven't added those addresses. - test.rxf(t, s, e, addr1, snmc, 0) - - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - - // Should receive a packet destined to the solicited node address of - // addr2/addr3 now that we have added added addr2. - test.rxf(t, s, e, addr1, snmc, 1) - - if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) - } - - // Should still receive a packet destined to the solicited node address of - // addr2/addr3 now that we have added addr3. - test.rxf(t, s, e, addr1, snmc, 2) - - if err := s.RemoveAddress(nicID, addr2); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err) - } - - // Should still receive a packet destined to the solicited node address of - // addr2/addr3 now that we have removed addr2. - test.rxf(t, s, e, addr1, snmc, 3) - - // Make sure addr3's endpoint does not get removed from the NIC by - // incrementing its reference count with a route. - r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err) - } - defer r.Release() - - if err := s.RemoveAddress(nicID, addr3); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err) - } - - // Should not receive a packet destined to the solicited node address of - // addr2/addr3 yet as both of them got removed, even though a route using - // addr3 exists. - test.rxf(t, s, e, addr1, snmc, 3) - }) - } -} - -// TestAddIpv6Address tests adding IPv6 addresses. -func TestAddIpv6Address(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - }{ - // This test is in response to b/140943433. - { - "Nil", - tcpip.Address([]byte(nil)), - }, - { - "ValidUnicast", - addr1, - }, - { - "ValidLinkLocalUnicast", - lladdr0, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { - t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) - } - - if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok { - t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber) - } else if addr.Address != test.addr { - t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr) - } - }) - } -} - -func TestReceiveIPv6ExtHdrs(t *testing.T) { - tests := []struct { - name string - extHdr func(nextHdr uint8) ([]byte, uint8) - shouldAccept bool - countersToBeIncremented func(*tcpip.Stats) []*tcpip.StatCounter - // Should we expect an ICMP response and if so, with what contents? - expectICMP bool - ICMPType header.ICMPv6Type - ICMPCode header.ICMPv6Code - pointer uint32 - multicast bool - }{ - { - name: "None", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, - shouldAccept: true, - expectICMP: false, - }, - { - name: "hopbyhop with router alert option", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - - // Router Alert option. - 5, 2, 0, 0, 0, 0, - }, hopByHopExtHdrID - }, - shouldAccept: true, - countersToBeIncremented: func(stats *tcpip.Stats) []*tcpip.StatCounter { - return []*tcpip.StatCounter{stats.IP.OptionRouterAlertReceived} - }, - }, - { - name: "hopbyhop with two router alert options", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Router Alert option. - 5, 2, 0, 0, 0, 0, - - // Router Alert option. - 5, 2, 0, 0, 0, 0, 0, 0, - }, hopByHopExtHdrID - }, - shouldAccept: false, - countersToBeIncremented: func(stats *tcpip.Stats) []*tcpip.StatCounter { - return []*tcpip.StatCounter{ - stats.IP.OptionRouterAlertReceived, - stats.IP.MalformedPacketsReceived, - } - }, - }, - { - name: "hopbyhop with unknown option skippable action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Skippable unknown. - 62, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID - }, - shouldAccept: true, - }, - { - name: "hopbyhop with unknown option discard action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard unknown. - 127, 6, 1, 2, 3, 4, 5, 6, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "hopbyhop with unknown option discard and send icmp action (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "hopbyhop with unknown option discard and send icmp action (multicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - multicast: true, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ Unknown option. - }, hopByHopExtHdrID - }, - multicast: true, - shouldAccept: false, - expectICMP: false, - }, - { - name: "routing with zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 1, 0, 2, 3, 4, 5, - }, routingExtHdrID - }, - shouldAccept: true, - }, - { - name: "routing with non-zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 1, 1, 2, 3, 4, 5, - }, routingExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6ErroneousHeader, - pointer: header.IPv6FixedHeaderSize + 2, - }, - { - name: "atomic fragment with zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 0, 0, 0, 0, 0, 0, - }, fragmentExtHdrID - }, - shouldAccept: true, - }, - { - name: "atomic fragment with non-zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 0, 0, 1, 2, 3, 4, - }, fragmentExtHdrID - }, - shouldAccept: true, - expectICMP: false, - }, - { - name: "fragment", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, - 1, 0, 1, 2, 3, 4, - }, fragmentExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{}, - noNextHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "unknown next header (first)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 0, 63, 4, 1, 2, 3, 4, - }, unknownHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6NextHeaderOffset, - }, - { - name: "unknown next header (not first)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - unknownHdrID, 0, - 63, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6FixedHeaderSize, - }, - { - name: "destination with unknown option skippable action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Skippable unknown. - 62, 6, 1, 2, 3, 4, 5, 6, - }, destinationExtHdrID - }, - shouldAccept: true, - expectICMP: false, - }, - { - name: "destination with unknown option discard action", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard unknown. - 127, 6, 1, 2, 3, 4, 5, 6, - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "destination with unknown option discard and send icmp action (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ 191 is an unknown option. - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "destination with unknown option discard and send icmp action (muilticast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP if option is unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - //^ 191 is an unknown option. - }, destinationExtHdrID - }, - multicast: true, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ 255 is unknown. - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownOption, - pointer: header.IPv6FixedHeaderSize + 8, - }, - { - name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Discard & send ICMP unless packet is for multicast destination if - // option is unknown. - 255, 6, 1, 2, 3, 4, 5, 6, - //^ 255 is unknown. - }, destinationExtHdrID - }, - shouldAccept: false, - expectICMP: false, - multicast: true, - }, - { - name: "atomic fragment - routing", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Fragment extension header. - routingExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Routing extension header. - nextHdr, 0, 1, 0, 2, 3, 4, 5, - }, fragmentExtHdrID - }, - shouldAccept: true, - }, - { - name: "hop by hop (with skippable unknown) - routing", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with skippable unknown option. - routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, - - // Routing extension header. - nextHdr, 0, 1, 0, 2, 3, 4, 5, - }, hopByHopExtHdrID - }, - shouldAccept: true, - }, - { - name: "routing - hop by hop (with skippable unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Routing extension header. - hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, - // ^^^ The HopByHop extension header may not appear after the first - // extension header. - - // Hop By Hop extension header with skippable unknown option. - nextHdr, 0, 62, 4, 1, 2, 3, 4, - }, routingExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6FixedHeaderSize, - }, - { - name: "routing - hop by hop (with send icmp unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Routing extension header. - hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, - // ^^^ The HopByHop extension header may not appear after the first - // extension header. - - nextHdr, 1, - - // Skippable unknown. - 63, 4, 1, 2, 3, 4, - - // Skippable unknown. - 191, 6, 1, 2, 3, 4, 5, 6, - }, routingExtHdrID - }, - shouldAccept: false, - expectICMP: true, - ICMPType: header.ICMPv6ParamProblem, - ICMPCode: header.ICMPv6UnknownHeader, - pointer: header.IPv6FixedHeaderSize, - }, - { - name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with skippable unknown option. - routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, - - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, - - // Fragment extension header. - destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Destination extension header with skippable unknown option. - nextHdr, 0, 63, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: true, - }, - { - name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with discard action for unknown option. - routingExtHdrID, 0, 65, 4, 1, 2, 3, 4, - - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, - - // Fragment extension header. - destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Destination extension header with skippable unknown option. - nextHdr, 0, 63, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - { - name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", - extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{ - // Hop By Hop extension header with skippable unknown option. - routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, - - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, - - // Fragment extension header. - destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, - - // Destination extension header with discard action for unknown - // option. - nextHdr, 0, 65, 4, 1, 2, 3, 4, - }, hopByHopExtHdrID - }, - shouldAccept: false, - expectICMP: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - - // Add a default route so that a return packet knows where to go. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8} - udpLength := header.UDPMinimumSize + len(udpPayload) - extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber)) - extHdrLen := len(extHdrBytes) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength) - - // Serialize UDP message. - u := header.UDP(hdr.Prepend(udpLength)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: uint16(udpLength), - }) - copy(u.Payload(), udpPayload) - - dstAddr := tcpip.Address(addr2) - if test.multicast { - dstAddr = header.IPv6AllNodesMulticastAddress - } - - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength)) - sum = header.Checksum(udpPayload, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - // Copy extension header bytes between the UDP message and the IPv6 - // fixed header. - copy(hdr.Prepend(extHdrLen), extHdrBytes) - - // Serialize IPv6 fixed header. - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - // We're lying about transport protocol here to be able to generate - // raw extension headers from the test definitions. - TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr), - HopLimit: 255, - SrcAddr: addr1, - DstAddr: dstAddr, - }) - - stats := s.Stats() - var counters []*tcpip.StatCounter - // Make sure that the counters we expect to be incremented are initially - // set to zero. - if fn := test.countersToBeIncremented; fn != nil { - counters = fn(&stats) - } - for i := range counters { - if got := counters[i].Value(); got != 0 { - t.Errorf("before writing packet: got test.countersToBeIncremented(&stats)[%d].Value() = %d, want = 0", i, got) - } - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - for i := range counters { - if got := counters[i].Value(); got != 1 { - t.Errorf("after writing packet: got test.countersToBeIncremented(&stats)[%d].Value() = %d, want = 1", i, got) - } - } - - udpReceiveStat := stats.UDP.PacketsReceived - if !test.shouldAccept { - if got := udpReceiveStat.Value(); got != 0 { - t.Errorf("got UDP Rx Packets = %d, want = 0", got) - } - - if !test.expectICMP { - if p, ok := e.Read(); ok { - t.Fatalf("unexpected packet received: %#v", p) - } - return - } - - // ICMP required. - p, ok := e.Read() - if !ok { - t.Fatalf("expected packet wasn't written out") - } - - // Pack the output packet into a single buffer.View as the checkers - // assume that. - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - pkt := vv.ToView() - if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want { - t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want) - } - - ipHdr := header.IPv6(pkt) - checker.IPv6(t, ipHdr, checker.ICMPv6( - checker.ICMPv6Type(test.ICMPType), - checker.ICMPv6Code(test.ICMPCode))) - - // We know we are looking at no extension headers in the error ICMP - // packets. - icmpPkt := header.ICMPv6(ipHdr.Payload()) - // We know we sent small packets that won't be truncated when reflected - // back to us. - originalPacket := icmpPkt.Payload() - if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want { - t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want) - } - if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" { - t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff) - } - return - } - - // Expect a UDP packet. - if got := udpReceiveStat.Value(); got != 1 { - t.Errorf("got UDP Rx Packets = %d, want = 1", got) - } - var buf bytes.Buffer - result, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Read: %s", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(udpPayload), - Total: len(udpPayload), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(udpPayload, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - - // Should not have any more UDP packets. - res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) - } - }) - } -} - -// fragmentData holds the IPv6 payload for a fragmented IPv6 packet. -type fragmentData struct { - srcAddr tcpip.Address - dstAddr tcpip.Address - nextHdr uint8 - data buffer.VectorisedView -} - -func TestReceiveIPv6Fragments(t *testing.T) { - const ( - udpPayload1Length = 256 - udpPayload2Length = 128 - // Used to test cases where the fragment blocks are not a multiple of - // the fragment block size of 8 (RFC 8200 section 4.5). - udpPayload3Length = 127 - udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize - udpMaximumSizeMinus15 = header.UDPMaximumSize - 15 - fragmentExtHdrLen = 8 - // Note, not all routing extension headers will be 8 bytes but this test - // uses 8 byte routing extension headers for most sub tests. - routingExtHdrLen = 8 - ) - - udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View { - payloadLen := len(payload) - for i := 0; i < payloadLen; i++ { - payload[i] = uint8(i) * multiplier - } - - udpLength := header.UDPMinimumSize + payloadLen - - hdr := buffer.NewPrependable(udpLength) - u := header.UDP(hdr.Prepend(udpLength)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: uint16(udpLength), - }) - copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) - sum = header.Checksum(payload, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - return hdr.View() - } - - var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte - udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:] - ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2) - - var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte - udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:] - ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2) - - var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte - udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:] - ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2) - - var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte - udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:] - ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2) - - var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte - udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:] - ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2) - - tests := []struct { - name string - expectedPayload []byte - fragments []fragmentData - expectedPayloads [][]byte - }{ - { - name: "No fragmentation", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: uint8(header.UDPProtocolNumber), - data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Atomic fragment", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), - - ipv6Payload1Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Atomic fragment with size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), - - ipv6Payload3Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "Two fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments out of order", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with different Next Header values", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - // NextHeader value is different than the one in the first fragment, so - // this NextHeader should be ignored. - buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with last fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload3Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload3Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "Two fragments with first fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+63, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload3Addr1ToAddr2[:63], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload3Addr1ToAddr2[63:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with different IDs", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, - udpMaximumSizeMinus15 >> 8, - udpMaximumSizeMinus15 & 0xff, - 0, 0, 0, 1}), - - ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, - }, - { - name: "Two fragments with MF flag reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, - udpMaximumSizeMinus15 >> 8, - (udpMaximumSizeMinus15 & 0xff) + 1, - 0, 0, 0, 1}), - - ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with per-fragment routing header with zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 0. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), - - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 0. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), - - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with per-fragment routing header with non-zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 1. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), - - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: routingExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Routing extension header. - // - // Segments left = 1. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), - - // Fragment extension header. - // - // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with routing header with zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), - - // Routing extension header. - // - // Segments left = 0. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with routing header with non-zero segments left", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), - - // Routing extension header. - // - // Segments left = 1. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with routing header with zero segments left across fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is fragmentExtHdrLen+8 because the - // first 8 bytes of the 16 byte routing extension header is in - // this fragment. - fragmentExtHdrLen+8, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), - - // Routing extension header (part 1) - // - // Segments left = 0. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}), - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is - // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of - // the 16 byte routing extension header is in this fagment. - fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 1, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), - - // Routing extension header (part 2) - buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), - - ipv6Payload1Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with routing header with non-zero segments left across fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is fragmentExtHdrLen+8 because the - // first 8 bytes of the 16 byte routing extension header is in - // this fragment. - fragmentExtHdrLen+8, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), - - // Routing extension header (part 1) - // - // Segments left = 1. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}), - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - // The length of this payload is - // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of - // the 16 byte routing extension header is in this fagment. - fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 1, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), - - // Routing extension header (part 2) - buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), - - ipv6Payload1Addr1ToAddr2, - }, - ), - }, - }, - expectedPayloads: nil, - }, - // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal" - // fragmented traffic. - { - name: "Two fragments with atomic", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - // This fragment has the same ID as the other fragments but is an atomic - // fragment. It should not interfere with the other fragments. - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2), - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}), - - ipv6Payload2Addr1ToAddr2, - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2}, - }, - { - name: "Two interleaved fragmented packets", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}), - - ipv6Payload2Addr1ToAddr2[:32], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 4, More = false, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}), - - ipv6Payload2Addr1ToAddr2[32:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, - }, - { - name: "Two interleaved fragmented packets from different sources but with same ID", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[:64], - }, - ), - }, - { - srcAddr: addr3, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - - ipv6Payload1Addr3ToAddr2[:32], - }, - ), - }, - { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - - ipv6Payload1Addr1ToAddr2[64:], - }, - ), - }, - { - srcAddr: addr3, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32, - []buffer.View{ - // Fragment extension header. - // - // Fragment offset = 4, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}), - - ipv6Payload1Addr3ToAddr2[32:], - }, - ), - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) - - // Serialize IPv6 fixed header. - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(f.data.Size()), - // We're lying about transport protocol here so that we can generate - // raw extension headers for the tests. - TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr), - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, - }) - - vv := hdr.View().ToVectorisedView() - vv.Append(f.data) - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { - t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) - } - - for i, p := range test.expectedPayloads { - var buf bytes.Buffer - _, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) Read: %s", i, err) - } - if diff := cmp.Diff(p, buf.Bytes()); diff != "" { - t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) - } - } - - res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) - } - }) - } -} - -func TestInvalidIPv6Fragments(t *testing.T) { - const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - nicID = 1 - hoplimit = 255 - ident = 1 - data = "TEST_INVALID_IPV6_FRAGMENTS" - ) - - type fragmentData struct { - ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6SerializableFragmentExtHdr - payload []byte - } - - tests := []struct { - name string - fragments []fragmentData - wantMalformedIPPackets uint64 - wantMalformedFragments uint64 - expectICMP bool - expectICMPType header.ICMPv6Type - expectICMPCode header.ICMPv6Code - expectICMPTypeSpecific uint32 - }{ - { - name: "fragment size is not a multiple of 8 and the M flag is true", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 9, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0 >> 3, - M: true, - Identification: ident, - }, - payload: []byte(data)[:9], - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - expectICMP: true, - expectICMPType: header.ICMPv6ParamProblem, - expectICMPCode: header.ICMPv6ErroneousHeader, - expectICMPTypeSpecific: header.IPv6PayloadLenOffset, - }, - { - name: "fragments reassembled into a payload exceeding the max IPv6 payload size", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, - M: false, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - expectICMP: true, - expectICMPType: header.ICMPv6ParamProblem, - expectICMPCode: header.ICMPv6ErroneousHeader, - expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */ - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - NewProtocol, - }, - }) - e := channel.New(1, 1500, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }}) - - var expectICMPPayload buffer.View - for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - encodeArgs := f.ipv6Fields - encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) - ip.Encode(&encodeArgs) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(f.payload) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - - if test.expectICMP { - expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader()) - } - - e.InjectInbound(ProtocolNumber, pkt) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { - t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want) - } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { - t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) - } - - reply, ok := e.Read() - if !test.expectICMP { - if ok { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - return - } - if !ok { - t.Fatal("expected ICMP error message missing") - } - - checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(addr2), - checker.DstAddr(addr1), - checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())), - checker.ICMPv6( - checker.ICMPv6Type(test.expectICMPType), - checker.ICMPv6Code(test.expectICMPCode), - checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), - checker.ICMPv6Payload([]byte(expectICMPPayload)), - ), - ) - }) - } -} - -func TestFragmentReassemblyTimeout(t *testing.T) { - const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - nicID = 1 - hoplimit = 255 - ident = 1 - data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" - ) - - type fragmentData struct { - ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6SerializableFragmentExtHdr - payload []byte - } - - tests := []struct { - name string - fragments []fragmentData - expectICMP bool - }{ - { - name: "first fragment only", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "two first fragments", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "second fragment only", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 8, - M: false, - Identification: ident, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: false, - }, - { - name: "two fragments with a gap", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 8, - M: false, - Identification: ident, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: true, - }, - { - name: "two fragments with a gap in reverse order", - fragments: []fragmentData{ - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 8, - M: false, - Identification: ident, - }, - payload: []byte(data)[16:], - }, - { - ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, - }, - ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ - FragmentOffset: 0, - M: true, - Identification: ident, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - NewProtocol, - }, - Clock: clock, - }) - - e := channel.New(1, 1500, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }}) - - var firstFragmentSent buffer.View - for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - encodeArgs := f.ipv6Fields - encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) - ip.Encode(&encodeArgs) - - fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(f.payload) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - - if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 { - firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) - } - - e.InjectInbound(ProtocolNumber, pkt) - } - - clock.Advance(ReassembleTimeout) - - reply, ok := e.Read() - if !test.expectICMP { - if ok { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - return - } - if !ok { - t.Fatal("expected ICMP error message missing") - } - if firstFragmentSent == nil { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - - checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(addr2), - checker.DstAddr(addr1), - checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6TimeExceeded), - checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), - checker.ICMPv6Payload([]byte(firstFragmentSent)), - ), - ) - }) - } -} - -func TestWriteStats(t *testing.T) { - const nPackets = 3 - tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectDropped int - expectWritten int - }{ - { - name: "Accept all", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectDropped: 0, - expectWritten: nPackets, - }, { - name: "Accept all with error", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectDropped: 0, - expectWritten: nPackets - 1, - }, { - name: "Drop all", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule. - t.Helper() - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectDropped: nPackets, - expectWritten: nPackets, - }, { - name: "Drop some", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule that matches only 1 - // of the 3 packets. - t.Helper() - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - // We'll match and DROP the last packet. - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} - // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectDropped: 1, - expectWritten: nPackets, - }, - } - - writers := []struct { - name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) - }{ - { - name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - nWritten := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { - return nWritten, err - } - nWritten++ - } - return nWritten, nil - }, - }, { - name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) - }, - }, - } - - for _, writer := range writers { - t.Run(writer.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) - rt := buildRoute(t, ep) - var pkts stack.PacketBufferList - for i := 0; i < nPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), - Data: buffer.NewView(0).ToVectorisedView(), - }) - pkt.TransportHeader().Push(header.UDPMinimumSize) - pkts.PushBack(pkt) - } - - test.setup(t, rt.Stack()) - - nWritten, _ := writer.writePackets(rt, pkts) - - if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) - } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { - t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) - } - if nWritten != test.expectWritten { - t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) - } - }) - } - }) - } -} - -func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC(1, _) failed: %s", err) - } - const ( - src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ) - if err := s.AddAddress(1, ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) - } - { - mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") - subnet, err := tcpip.NewSubnet(dst, mask) - if err != nil { - t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err) - } - return rt -} - -// limitedMatcher is an iptables matcher that matches after a certain number of -// packets are checked against it. -type limitedMatcher struct { - limit int -} - -// Name implements Matcher.Name. -func (*limitedMatcher) Name() string { - return "limitedMatcher" -} - -// Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { - if lm.limit == 0 { - return true, false - } - lm.limit-- - return false, false -} - -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEndpointAfterClose { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - -type fragmentInfo struct { - offset uint16 - more bool - payloadSize uint16 -} - -var fragmentationTests = []struct { - description string - mtu uint32 - gso *stack.GSO - transHdrLen int - payloadSize int - wantFragments []fragmentInfo -}{ - { - description: "No fragmentation", - mtu: header.IPv6MinimumMTU, - gso: nil, - transHdrLen: 0, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1000, more: false}, - }, - }, - { - description: "Fragmented", - mtu: header.IPv6MinimumMTU, - gso: nil, - transHdrLen: 0, - payloadSize: 2000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 776, more: false}, - }, - }, - { - description: "Fragmented with mtu not a multiple of 8", - mtu: header.IPv6MinimumMTU + 1, - gso: nil, - transHdrLen: 0, - payloadSize: 2000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 776, more: false}, - }, - }, - { - description: "No fragmentation with big header", - mtu: 2000, - gso: nil, - transHdrLen: 100, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1100, more: false}, - }, - }, - { - description: "Fragmented with gso none", - mtu: header.IPv6MinimumMTU, - gso: &stack.GSO{Type: stack.GSONone}, - transHdrLen: 0, - payloadSize: 1400, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 176, more: false}, - }, - }, - { - description: "Fragmented with big header", - mtu: header.IPv6MinimumMTU, - gso: nil, - transHdrLen: 100, - payloadSize: 1200, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 76, more: false}, - }, - }, -} - -func TestFragmentationWritePacket(t *testing.T) { - const ( - ttl = 42 - tos = stack.DefaultTOS - transportProto = tcp.ProtocolNumber - ) - - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - source := pkt.Clone() - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if err != nil { - t.Fatalf("WritePacket(_, _, _): = %s", err) - } - if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) - } - if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } - }) - } -} - -func TestFragmentationWritePackets(t *testing.T) { - const ttl = 42 - tests := []struct { - description string - insertBefore int - insertAfter int - }{ - { - description: "Single packet", - insertBefore: 0, - insertAfter: 0, - }, - { - description: "With packet before", - insertBefore: 1, - insertAfter: 0, - }, - { - description: "With packet after", - insertBefore: 0, - insertAfter: 1, - }, - { - description: "With packet before and after", - insertBefore: 1, - insertAfter: 1, - }, - } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) - - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - var pkts stack.PacketBufferList - for i := 0; i < test.insertBefore; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - source := pkt - pkts.PushBack(pkt.Clone()) - for i := 0; i < test.insertAfter; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - - wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }) - if n != wantTotalPackets || err != nil { - t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets) - } - if got := len(ep.WrittenPackets); got != wantTotalPackets { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) - } - if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - - if wantTotalPackets == 0 { - return - } - - fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] - if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } - }) - } - }) - } -} - -// TestFragmentationErrors checks that errors are returned from WritePacket -// correctly. -func TestFragmentationErrors(t *testing.T) { - const ttl = 42 - - tests := []struct { - description string - mtu uint32 - transHdrLen int - payloadSize int - allowPackets int - outgoingErrors int - mockError tcpip.Error - wantError tcpip.Error - }{ - { - description: "No frag", - mtu: 2000, - payloadSize: 1000, - transHdrLen: 0, - allowPackets: 0, - outgoingErrors: 1, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag", - mtu: 1300, - payloadSize: 3000, - transHdrLen: 0, - allowPackets: 0, - outgoingErrors: 3, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on second frag", - mtu: 1500, - payloadSize: 4000, - transHdrLen: 0, - allowPackets: 1, - outgoingErrors: 2, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error when MTU is smaller than transport header", - mtu: header.IPv6MinimumMTU, - transHdrLen: 1500, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 1, - mockError: nil, - wantError: &tcpip.ErrMessageTooLong{}, - }, - { - description: "Error when MTU is smaller than IPv6 minimum MTU", - mtu: header.IPv6MinimumMTU - 1, - transHdrLen: 0, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 1, - mockError: nil, - wantError: &tcpip.ErrInvalidEndpointState{}, - }, - } - - for _, ft := range tests { - t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) - r := buildRoute(t, ep) - err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if diff := cmp.Diff(ft.wantError, err); diff != "" { - t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { - t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) - } - }) - } -} - -func TestForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - randomSequence = 123 - randomIdent = 42 - ) - - ipv6Addr1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10::1").To16()), - PrefixLen: 64, - } - ipv6Addr2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11::1").To16()), - PrefixLen: 64, - } - remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) - remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) - - tests := []struct { - name string - TTL uint8 - expectErrorICMP bool - }{ - { - name: "TTL of zero", - TTL: 0, - expectErrorICMP: true, - }, - { - name: "TTL of one", - TTL: 1, - expectErrorICMP: true, - }, - { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, - }, - { - name: "TTL of three", - TTL: 3, - expectErrorICMP: false, - }, - { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1} - if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err) - } - - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2} - if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: ipv6Addr1.Subnet(), - NIC: nicID1, - }, - { - Destination: ipv6Addr2.Subnet(), - NIC: nicID2, - }, - }) - - if err := s.SetForwarding(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) - } - - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize) - icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv6EchoRequest) - icmp.SetCode(header.ICMPv6UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: remoteIPv6Addr1, - Dst: remoteIPv6Addr2, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: test.TTL, - SrcAddr: remoteIPv6Addr1, - DstAddr: remoteIPv6Addr2, - }) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e1.InjectInbound(ProtocolNumber, requestPkt) - - if test.expectErrorICMP { - reply, ok := e1.Read() - if !ok { - t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") - } - - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv6Addr1.Address), - checker.DstAddr(remoteIPv6Addr1), - checker.TTL(DefaultTTL), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6TimeExceeded), - checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), - checker.ICMPv6Payload([]byte(hdr.View())), - ), - ) - - if n := e2.Drain(); n != 0 { - t.Fatalf("got e2.Drain() = %d, want = 0", n) - } - } else { - reply, ok := e2.Read() - if !ok { - t.Fatal("expected ICMP Echo Request packet through outgoing NIC") - } - - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv6Addr1), - checker.DstAddr(remoteIPv6Addr2), - checker.TTL(test.TTL-1), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoRequest), - checker.ICMPv6Code(header.ICMPv6UnusedCode), - checker.ICMPv6Payload(nil), - ), - ) - - if n := e1.Drain(); n != 0 { - t.Fatalf("got e1.Drain() = %d, want = 0", n) - } - } - }) - } -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - // At this point, the Stack's stats and the NetworkEndpoint's stats are - // supposed to be bound. - refStack := s.Stats() - refEP := ep.stats.localStats - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refStack.IP).Elem(), reflect.ValueOf(&refEP.IP).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refStack.ICMP.V6).Elem(), reflect.ValueOf(&refEP.ICMP).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go deleted file mode 100644 index 9a425e50a..000000000 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ /dev/null @@ -1,450 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" -) - -var ( - linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) - globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) -) - -func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { - t.Helper() - - checker.IPv6WithExtHdr(t, p, - checker.IPv6ExtHdr( - checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), - ), - checker.SrcAddr(localAddress), - checker.DstAddr(remoteAddress), - checker.TTL(header.MLDHopLimit), - checker.MLD(mldType, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(groupAddress), - ), - ) -} - -func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - MLD: ipv6.MLDOptions{ - Enabled: true, - }, - })}, - }) - e := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // The stack will join an address's solicited node multicast address when - // an address is added. An MLD report message should be sent for the - // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) - } - - // The stack will leave an address's solicited node multicast address when - // an address is removed. An MLD done message should be sent for the - // solicited-node group. - if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a done message to be sent") - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) - } -} - -func TestSendQueuedMLDReports(t *testing.T) { - const ( - nicID = 1 - maxReports = 2 - ) - - tests := []struct { - name string - dadTransmits uint8 - retransmitTimer time.Duration - }{ - { - name: "DAD Disabled", - dadTransmits: 0, - retransmitTimer: 0, - }, - { - name: "DAD Enabled", - dadTransmits: 1, - retransmitTimer: time.Second, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dadTransmits, - RetransmitTimer: test.retransmitTimer, - }, - MLD: ipv6.MLDOptions{ - Enabled: true, - }, - })}, - Clock: clock, - }) - - // Allow space for an extra packet so we can observe packets that were - // unexpectedly sent. - e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - resolveDAD := func(addr, snmc tcpip.Address) { - clock.Advance(dadResolutionTime) - if p, ok := e.Read(); !ok { - t.Fatal("expected DAD packet") - } else { - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(addr), - checker.NDPNSOptions(nil), - )) - } - } - - var reportCounter uint64 - reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - var doneCounter uint64 - doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - if got := doneStat.Value(); got != doneCounter { - t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) - } - - // Joining a group without an assigned address should send an MLD report - // with the unspecified address. - if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) - } - reportCounter++ - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Errorf("expected MLD report for %s", globalMulticastAddr) - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } - if t.Failed() { - t.FailNow() - } - - // Adding a global address should not send reports for the already joined - // group since we should only send queued reports when a link-local - // address is assigned. - // - // Note, we will still expect to send a report for the global address's - // solicited node address from the unspecified address as per RFC 3590 - // section 4. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) - } - reportCounter++ - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Errorf("expected MLD report for %s", globalAddrSNMC) - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) - } - if dadResolutionTime != 0 { - // Reports should not be sent when the address resolves. - resolveDAD(globalAddr, globalAddrSNMC) - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - } - // Leave the group since we don't care about the global address's - // solicited node multicast group membership. - if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { - t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) - } - if got := doneStat.Value(); got != doneCounter { - t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) - } - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } - if t.Failed() { - t.FailNow() - } - - // Adding a link-local address should send a report for its solicited node - // address and globalMulticastAddr. - if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) - } - if dadResolutionTime != 0 { - reportCounter++ - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) - } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) - } - resolveDAD(linkLocalAddr, linkLocalAddrSNMC) - } - - // We expect two batches of reports to be sent (1 batch when the - // link-local address is assigned, and another after the maximum - // unsolicited report interval. - for i := 0; i < 2; i++ { - // We expect reports to be sent (one for globalMulticastAddr and another - // for linkLocalAddrSNMC). - reportCounter += maxReports - if got := reportStat.Value(); got != reportCounter { - t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) - } - - addrs := map[tcpip.Address]bool{ - globalMulticastAddr: false, - linkLocalAddrSNMC: false, - } - for range addrs { - p, ok := e.Read() - if !ok { - t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) - } - - addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() - if seen, ok := addrs[addr]; !ok { - t.Fatalf("got unexpected packet destined to %s", addr) - } else if seen { - t.Fatalf("got another packet destined to %s", addr) - } - - addrs[addr] = true - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) - - clock.Advance(ipv6.UnsolicitedReportIntervalMax) - } - } - - // Should not send any more reports. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } - }) - } -} - -// createAndInjectMLDPacket creates and injects an MLD packet with the -// specified fields. -func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, hopLimit uint8, srcAddress tcpip.Address, withRouterAlertOption bool, routerAlertValue header.IPv6RouterAlertValue) { - var extensionHeaders header.IPv6ExtHdrSerializer - if withRouterAlertOption { - extensionHeaders = header.IPv6ExtHdrSerializer{ - header.IPv6SerializableHopByHopExtHdr{ - &header.IPv6RouterAlertOption{Value: routerAlertValue}, - }, - } - } - - extensionHeadersLength := extensionHeaders.Length() - payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize - buf := buffer.NewView(header.IPv6MinimumSize + payloadLength) - - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - HopLimit: hopLimit, - TransportProtocol: header.ICMPv6ProtocolNumber, - SrcAddr: srcAddress, - DstAddr: header.IPv6AllNodesMulticastAddress, - ExtensionHeaders: extensionHeaders, - }) - - icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:]) - icmp.SetType(mldType) - mld := header.MLD(icmp.MessageBody()) - mld.SetMaximumResponseDelay(0) - mld.SetMulticastAddress(header.IPv6Any) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: srcAddress, - Dst: header.IPv6AllNodesMulticastAddress, - })) - - e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) -} - -func TestMLDPacketValidation(t *testing.T) { - const ( - nicID = 1 - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ) - - tests := []struct { - name string - messageType header.ICMPv6Type - srcAddr tcpip.Address - includeRouterAlertOption bool - routerAlertValue header.IPv6RouterAlertValue - hopLimit uint8 - expectValidMLD bool - getMessageTypeStatValue func(tcpip.Stats) uint64 - }{ - { - name: "valid", - messageType: header.ICMPv6MulticastListenerQuery, - includeRouterAlertOption: true, - routerAlertValue: header.IPv6RouterAlertMLD, - srcAddr: linkLocalAddr2, - hopLimit: header.MLDHopLimit, - expectValidMLD: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerQuery.Value() }, - }, - { - name: "bad hop limit", - messageType: header.ICMPv6MulticastListenerReport, - includeRouterAlertOption: true, - routerAlertValue: header.IPv6RouterAlertMLD, - srcAddr: linkLocalAddr2, - hopLimit: header.MLDHopLimit + 1, - expectValidMLD: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerReport.Value() }, - }, - { - name: "src ip not link local", - messageType: header.ICMPv6MulticastListenerReport, - includeRouterAlertOption: true, - routerAlertValue: header.IPv6RouterAlertMLD, - srcAddr: globalAddr, - hopLimit: header.MLDHopLimit, - expectValidMLD: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerReport.Value() }, - }, - { - name: "missing router alert ip option", - messageType: header.ICMPv6MulticastListenerDone, - includeRouterAlertOption: false, - srcAddr: linkLocalAddr2, - hopLimit: header.MLDHopLimit, - expectValidMLD: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerDone.Value() }, - }, - { - name: "incorrect router alert value", - messageType: header.ICMPv6MulticastListenerDone, - includeRouterAlertOption: true, - routerAlertValue: header.IPv6RouterAlertRSVP, - srcAddr: linkLocalAddr2, - hopLimit: header.MLDHopLimit, - expectValidMLD: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerDone.Value() }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - MLD: ipv6.MLDOptions{ - Enabled: true, - }, - })}, - }) - e := channel.New(nicID, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - stats := s.Stats() - // Verify that every relevant stats is zero'd before we send a packet. - if got := test.getMessageTypeStatValue(s.Stats()); got != 0 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got) - } - if got := stats.ICMP.V6.PacketsReceived.Invalid.Value(); got != 0 { - t.Errorf("got stats.ICMP.V6.PacketsReceived.Invalid.Value() = %d, want = 0", got) - } - if got := stats.IP.PacketsDelivered.Value(); got != 0 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got) - } - createAndInjectMLDPacket(e, test.messageType, test.hopLimit, test.srcAddr, test.includeRouterAlertOption, test.routerAlertValue) - // We always expect the packet to pass IP validation. - if got := stats.IP.PacketsDelivered.Value(); got != 1 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got) - } - // Even when the MLD-specific validation checks fail, we expect the - // corresponding MLD counter to be incremented. - if got := test.getMessageTypeStatValue(s.Stats()); got != 1 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got) - } - var expectedInvalidCount uint64 - if !test.expectValidMLD { - expectedInvalidCount = 1 - } - if got := stats.ICMP.V6.PacketsReceived.Invalid.Value(); got != expectedInvalidCount { - t.Errorf("got stats.ICMP.V6.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go deleted file mode 100644 index 6e850fd46..000000000 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ /dev/null @@ -1,1362 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "context" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -// setupStackAndEndpoint creates a stack with a single NIC with a link-local -// address llladdr and an IPv6 endpoint to a remote with link-local address -// rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - t.Cleanup(ep.Close) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := llladdr.WithPrefix() - if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - addressEP.DecRef() - } - - return s, ep -} - -var _ NDPDispatcher = (*testNDPDispatcher)(nil) - -// testNDPDispatcher is an NDPDispatcher only allows default router discovery. -type testNDPDispatcher struct { - addr tcpip.Address -} - -func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { -} - -func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { - t.addr = addr - return true -} - -func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) { - t.addr = addr -} - -func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false -} - -func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { -} - -func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return false -} - -func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { -} - -func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) { -} - -func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) { -} - -func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) { -} - -func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) { -} - -func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { - var ndpDisp testNDPDispatcher - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) - } - - ipv6EP := ep.(*endpoint) - ipv6EP.mu.Lock() - ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour) - ipv6EP.mu.Unlock() - - if ndpDisp.addr != lladdr1 { - t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) - } - - ndpDisp.addr = "" - ndpEP := ep.(stack.NDPEndpoint) - ndpEP.InvalidateDefaultRouter(lladdr1) - if ndpDisp.addr != lladdr1 { - t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) - } -} - -type linkResolutionResult struct { - linkAddr tcpip.LinkAddress - ok bool -} - -// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a -// valid NDP NS message with the Source Link Layer Address option results in a -// new entry in the link address cache for the sender of the message. -func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(0, 1280, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(lladdr0) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: lladdr1, - Dst: lladdr0, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - neighbors, err := s.Neighbors(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) - } - - neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) - for _, n := range neighbors { - if existing, ok := neighborByAddr[n.Addr]; ok { - if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) - } - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) - } - neighborByAddr[n.Addr] = n - } - - if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 { - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - - if !ok { - t.Fatalf("expected a neighbor entry for %q", lladdr1) - } - if neigh.LinkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr) - } - if neigh.State != stack.Stale { - t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale) - } - } else { - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } - - if ok { - t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) - } - } - }) - } -} - -func TestNeighborSolicitationResponse(t *testing.T) { - const nicID = 1 - nicAddr := lladdr0 - remoteAddr := lladdr1 - nicAddrSNMC := header.SolicitedNodeAddr(nicAddr) - nicLinkAddr := linkAddr0 - remoteLinkAddr0 := linkAddr1 - remoteLinkAddr1 := linkAddr2 - - tests := []struct { - name string - nsOpts header.NDPOptionsSerializer - nsSrcLinkAddr tcpip.LinkAddress - nsSrc tcpip.Address - nsDst tcpip.Address - nsInvalid bool - naDstLinkAddr tcpip.LinkAddress - naSolicited bool - naSrc tcpip.Address - naDst tcpip.Address - performsLinkResolution bool - }{ - { - name: "Unspecified source to solicited-node multicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddrSNMC, - nsInvalid: false, - naDstLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), - naSolicited: false, - naSrc: nicAddr, - naDst: header.IPv6AllNodesMulticastAddress, - }, - { - name: "Unspecified source with source ll option to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddrSNMC, - nsInvalid: true, - }, - { - name: "Unspecified source to unicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddr, - nsInvalid: true, - }, - { - name: "Unspecified source with source ll option to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: header.IPv6Any, - nsDst: nicAddr, - nsInvalid: true, - }, - { - name: "Specified source with 1 source ll to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source with 1 source ll different from route to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr1, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source to multicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: true, - }, - { - name: "Specified source with 2 source ll to multicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddrSNMC, - nsInvalid: true, - }, - - { - name: "Specified source to unicast destination", - nsOpts: nil, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - // Since we send a unicast solicitations to a node without an entry for - // the remote, the node needs to perform neighbor discovery to get the - // remote's link address to send the advertisement response. - performsLinkResolution: true, - }, - { - name: "Specified source with 1 source ll to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source with 1 source ll different from route to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr1, - naSolicited: true, - naSrc: nicAddr, - naDst: remoteAddr, - }, - { - name: "Specified source with 2 source ll to unicast destination", - nsOpts: header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), - header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), - }, - nsSrcLinkAddr: remoteLinkAddr0, - nsSrc: remoteAddr, - nsDst: nicAddr, - nsInvalid: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(1, 1280, nicLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(nicAddr) - opts := ns.Options() - opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: test.nsSrc, - Dst: test.nsDst, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if test.nsInvalid { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - - if p, got := e.Read(); got { - t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) - } - - // If we expected the NS to be invalid, we have nothing else to check. - return - } - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - if test.performsLinkResolution { - p, got := e.ReadContext(context.Background()) - if !got { - t.Fatal("expected an NDP NS response") - } - - respNSDst := header.SolicitedNodeAddr(test.nsSrc) - var want stack.RouteInfo - want.NetProto = ProtocolNumber - want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) - if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { - t.Errorf("route info mismatch (-want +got):\n%s", diff) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(nicAddr), - checker.DstAddr(respNSDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(test.nsSrc), - checker.NDPNSOptions([]header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(nicLinkAddr), - }), - )) - - ser := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - } - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(test.nsSrc) - na.Options().Serialize(ser) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: test.nsSrc, - Dst: nicAddr, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, - }) - e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - p, got := e.ReadContext(context.Background()) - if !got { - t.Fatal("expected an NDP NA response") - } - - if p.Route.LocalAddress != test.naSrc { - t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) - } - if p.Route.LocalLinkAddress != nicLinkAddr { - t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) - } - if p.Route.RemoteAddress != test.naDst { - t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) - } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.naSrc), - checker.DstAddr(test.naDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNA( - checker.NDPNASolicitedFlag(test.naSolicited), - checker.NDPNATargetAddress(nicAddr), - checker.NDPNAOptions([]header.NDPOption{ - header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), - }), - )) - }) - } -} - -// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a -// valid NDP NA message with the Target Link Layer Address option does not -// result in a new entry in the neighbor cache for the target of the message. -func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - isValid bool - }{ - { - name: "Valid", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, - isValid: true, - }, - { - name: "Too Small", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, - }, - { - name: "Multiple", - optsBuf: []byte{ - 2, 1, 2, 3, 4, 5, 6, 7, - 2, 1, 2, 3, 4, 5, 6, 8, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(0, 1280, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.MessageBody()) - ns.SetTargetAddress(lladdr1) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: lladdr1, - Dst: lladdr0, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - neighbors, err := s.Neighbors(nicID, ProtocolNumber) - if err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) - } - - neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) - for _, n := range neighbors { - if existing, ok := neighborByAddr[n.Addr]; ok { - if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) - } - t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) - } - neighborByAddr[n.Addr] = n - } - - if neigh, ok := neighborByAddr[lladdr1]; ok { - t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) - } - - if test.isValid { - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - } else { - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } - } - }) - } -} - -func TestNDPValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - return s, ep - } - - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { - var extHdrs header.IPv6ExtHdrSerializer - if atomicFragment { - extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) - } - extHdrsLen := extHdrs.Length() - - ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen) - header.IPv6(ip).Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + extHdrsLen), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: hopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - ExtensionHeaders: extHdrs, - }) - vv := ip.ToVectorisedView() - vv.AppendView(payload) - ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - var sllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - routerOnly bool - }{ - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - routerOnly: true, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - extraData: sllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } - - subTests := []struct { - name string - atomicFragment bool - hopLimit uint8 - code header.ICMPv6Code - valid bool - }{ - { - name: "Valid", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 0, - valid: true, - }, - { - name: "Fragmented", - atomicFragment: true, - hopLimit: header.NDPHopLimit, - code: 0, - valid: false, - }, - { - name: "Invalid hop limit", - atomicFragment: false, - hopLimit: header.NDPHopLimit - 1, - code: 0, - valid: false, - }, - { - name: "Invalid ICMPv6 code", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 1, - valid: false, - }, - } - - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" - } - - t.Run(name, func(t *testing.T) { - for _, test := range subTests { - t.Run(test.name, func(t *testing.T) { - s, ep := setup(t) - - if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp[:typ.size], - Src: lladdr0, - Dst: lladdr1, - PayloadCsum: header.Checksum(typ.extraData /* initial */, 0), - PayloadLen: len(typ.extraData), - })) - - // Rx count of the NDP message should initially be 0. - if got := typStat.Value(); got != 0 { - t.Errorf("got %s = %d, want = 0", typ.name, got) - } - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - - // RouterOnlyPacketsReceivedByHost count should initially be 0. - if got := routerOnly.Value(); got != 0 { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - - if t.Failed() { - t.FailNow() - } - - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) - - // Rx count of the NDP packet should have increased. - if got := typStat.Value(); got != 1 { - t.Errorf("got %s = %d, want = 1", typ.name, got) - } - - want := uint64(0) - if !test.valid { - // Invalid count should have increased. - want = 1 - } - if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) - } - - want = 0 - if test.valid && !isRouter && typ.routerOnly { - // RouterOnlyPacketsReceivedByHost count should have increased. - want = 1 - } - if got := routerOnly.Value(); got != want { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) - } - - }) - } - }) - } - } -} - -// TestNeighborAdvertisementValidation tests that the NIC validates received -// Neighbor Advertisements. -// -// In particular, if the IP Destination Address is a multicast address, and the -// Solicited flag is not zero, the Neighbor Advertisement is invalid and should -// be discarded. -func TestNeighborAdvertisementValidation(t *testing.T) { - tests := []struct { - name string - ipDstAddr tcpip.Address - solicitedFlag bool - valid bool - }{ - { - name: "Multicast IP destination address with Solicited flag set", - ipDstAddr: header.IPv6AllNodesMulticastAddress, - solicitedFlag: true, - valid: false, - }, - { - name: "Multicast IP destination address with Solicited flag unset", - ipDstAddr: header.IPv6AllNodesMulticastAddress, - solicitedFlag: false, - valid: true, - }, - { - name: "Unicast IP destination address with Solicited flag set", - ipDstAddr: lladdr0, - solicitedFlag: true, - valid: true, - }, - { - name: "Unicast IP destination address with Solicited flag unset", - ipDstAddr: lladdr0, - solicitedFlag: false, - valid: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - e := channel.New(0, header.IPv6MinimumMTU, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetTargetAddress(lladdr1) - na.SetSolicitedFlag(test.solicitedFlag) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: lladdr1, - Dst: test.ipDstAddr, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: test.ipDstAddr, - }) - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - rxNA := stats.NeighborAdvert - - if got := rxNA.Value(); got != 0 { - t.Fatalf("got rxNA = %d, want = 0", got) - } - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if got := rxNA.Value(); got != 1 { - t.Fatalf("got rxNA = %d, want = 1", got) - } - var wantInvalid uint64 = 1 - if test.valid { - wantInvalid = 0 - } - if got := invalid.Value(); got != wantInvalid { - t.Fatalf("got invalid = %d, want = %d", got, wantInvalid) - } - // As per RFC 4861 section 7.2.5: - // When a valid Neighbor Advertisement is received ... - // If no entry exists, the advertisement SHOULD be silently discarded. - // There is no need to create an entry if none exists, since the - // recipient has apparently not initiated any communication with the - // target. - if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) - } - }) - } -} - -// TestRouterAdvertValidation tests that when the NIC is configured to handle -// NDP Router Advertisement packets, it validates the Router Advertisement -// properly before handling them. -func TestRouterAdvertValidation(t *testing.T) { - tests := []struct { - name string - src tcpip.Address - hopLimit uint8 - code header.ICMPv6Code - ndpPayload []byte - expectedSuccess bool - }{ - { - "OK", - lladdr0, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - true, - }, - { - "NonLinkLocalSourceAddr", - addr1, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "HopLimitNot255", - lladdr0, - 254, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "NonZeroCode", - lladdr0, - 255, - 1, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - false, - }, - { - "NDPPayloadTooSmall", - lladdr0, - 255, - 0, - []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, - }, - false, - }, - { - "OKWithOptions", - lladdr0, - 255, - 0, - []byte{ - // RA payload - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - - // Option #1 (TargetLinkLayerAddress) - 2, 1, 0, 0, 0, 0, 0, 0, - - // Option #2 (unrecognized) - 255, 1, 0, 0, 0, 0, 0, 0, - - // Option #3 (PrefixInformation) - 3, 4, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - true, - }, - { - "OptionWithZeroLength", - lladdr0, - 255, - 0, - []byte{ - // RA payload - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - - // Option #1 (TargetLinkLayerAddress) - // Invalid as it has 0 length. - 2, 0, 0, 0, 0, 0, 0, 0, - - // Option #2 (unrecognized) - 255, 1, 0, 0, 0, 0, 0, 0, - - // Option #3 (PrefixInformation) - 3, 4, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - }, - false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.MessageBody(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: test.src, - Dst: header.IPv6AllNodesMulticastAddress, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } - - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - } - }) - } -} - -// TestCheckDuplicateAddress checks that calls to CheckDuplicateAddress and DAD -// performed when adding new addresses do not interfere with each other. -func TestCheckDuplicateAddress(t *testing.T) { - const nicID = 1 - - clock := faketime.NewManualClock() - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - s := stack.New(stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ - DADConfigs: dadConfigs, - })}, - }) - // This test is expected to send at max 2 DAD messages. We allow an extra - // packet to be stored to catch unexpected packets. - e := channel.New(3, header.IPv6MinimumMTU, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - dadPacketsSent := 1 - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - // Start DAD for the address we just added. - // - // Even though the stack will perform DAD before the added address transitions - // from tentative to assigned, this DAD request should be independent of that. - ch := make(chan stack.DADResult, 3) - dadRequestsMade := 1 - dadPacketsSent++ - if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err) - } else if res != stack.DADStarting { - t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting) - } - - // Remove the address and make sure our DAD request was not stopped. - if err := s.RemoveAddress(nicID, lladdr0); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err) - } - // Should not restart DAD since we already requested DAD above - the handler - // should be called when the original request compeletes so we should not send - // an extra DAD message here. - dadRequestsMade++ - if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err) - } else if res != stack.DADAlreadyRunning { - t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning) - } - - // Wait for DAD to complete. - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) - for i := 0; i < dadRequestsMade; i++ { - if diff := cmp.Diff(&stack.DADSucceeded{}, <-ch); diff != "" { - t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) - } - } - // Should have no more results. - select { - case r := <-ch: - t.Errorf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } - - snmc := header.SolicitedNodeAddr(lladdr0) - remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) - - for i := 0; i < dadPacketsSent; i++ { - p, ok := e.Read() - if !ok { - t.Fatalf("expected %d-th DAD message", i) - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Errorf("(i=%d) got p.Proto = %d, want = %d", i, p.Proto, header.IPv6ProtocolNumber) - } - - if p.Route.RemoteLinkAddress != remoteLinkAddr { - t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", i, p.Route.RemoteLinkAddress, remoteLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions(nil), - )) - } - - // Should have no more packets. - if p, ok := e.Read(); ok { - t.Errorf("got unexpected packet = %#v", p) - } -} diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go deleted file mode 100644 index ecd5003a7..000000000 --- a/pkg/tcpip/network/multicast_group_test.go +++ /dev/null @@ -1,1274 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "fmt" - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - - stackIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") - defaultIPv4PrefixLength = 24 - linkLocalIPv6Addr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalIPv6Addr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - - ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") - ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") - ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") - ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") - ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") - - igmpMembershipQuery = uint8(header.IGMPMembershipQuery) - igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) - igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport) - igmpLeaveGroup = uint8(header.IGMPLeaveGroup) - mldQuery = uint8(header.ICMPv6MulticastListenerQuery) - mldReport = uint8(header.ICMPv6MulticastListenerReport) - mldDone = uint8(header.ICMPv6MulticastListenerDone) - - maxUnsolicitedReports = 2 -) - -var ( - // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the - // NIC will wait before sending an unsolicited report after joining a - // multicast group, in deciseconds. - unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 { - const decisecond = time.Second / 10 - if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { - panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax)) - } - return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) - }() - - ipv6AddrSNMC = header.SolicitedNodeAddr(linkLocalIPv6Addr1) -) - -// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet -// sent to the provided address with the passed fields set. -func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) { - t.Helper() - - payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv6WithExtHdr(t, payload, - checker.IPv6ExtHdr( - checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), - ), - checker.SrcAddr(linkLocalIPv6Addr1), - checker.DstAddr(remoteAddress), - // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. - checker.TTL(1), - checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize, - checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond), - checker.MLDMulticastAddress(groupAddress), - ), - ) -} - -// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet -// sent to the provided address with the passed fields set. -func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) { - t.Helper() - - payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv4(t, payload, - checker.SrcAddr(stackIPv4Addr), - checker.DstAddr(remoteAddress), - // TTL for an IGMP message must be 1 as per RFC 2236 section 2. - checker.TTL(1), - checker.IPv4RouterAlert(), - checker.IGMP( - checker.IGMPType(header.IGMPType(igmpType)), - checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), - checker.IGMPGroupAddress(groupAddress), - ), - ) -} - -func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { - t.Helper() - - e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) - s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) - return e, s, clock -} - -func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { - t.Helper() - - igmpEnabled := v4 && mgpEnabled - mldEnabled := !v4 && mgpEnabled - - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocolWithOptions(ipv4.Options{ - IGMP: ipv4.IGMPOptions{ - Enabled: igmpEnabled, - }, - }), - ipv6.NewProtocolWithOptions(ipv6.Options{ - MLD: ipv6.MLDOptions{ - Enabled: mldEnabled, - }, - }), - }, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - addr := tcpip.AddressWithPrefix{ - Address: stackIPv4Addr, - PrefixLen: defaultIPv4PrefixLength, - } - if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) - } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err) - } - - return s, clock -} - -// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join -// when it is created with an IPv6 address. -// -// To not interfere with tests, checkInitialIPv6Groups will leave the added -// address's solicited node multicast group so that the tests can all assume -// the NIC has not joined any IPv6 groups. -func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { - t.Helper() - - stats := s.Stats().ICMP.V6.PacketsSent - - reportCounter++ - if got := stats.MulticastListenerReport.Value(); got != reportCounter { - t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) - } - - // Leave the group to not affect the tests. This is fine since we are not - // testing DAD or the solicited node address specifically. - if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { - t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) - } - leaveCounter++ - if got := stats.MulticastListenerDone.Value(); got != leaveCounter { - t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - - return reportCounter, leaveCounter -} - -// createAndInjectIGMPPacket creates and injects an IGMP packet with the -// specified fields. -func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) { - options := header.IPv4OptionsSerializer{ - &header.IPv4SerializableRouterAlertOption{}, - } - buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize) - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: header.IGMPTTL, - Protocol: uint8(header.IGMPProtocolNumber), - SrcAddr: remoteIPv4Addr, - DstAddr: header.IPv4AllSystems, - Options: options, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - igmp := header.IGMP(ip.Payload()) - igmp.SetType(header.IGMPType(igmpType)) - igmp.SetMaxRespTime(maxRespTime) - igmp.SetGroupAddress(groupAddress) - igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - - e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) -} - -// createAndInjectMLDPacket creates and injects an MLD packet with the -// specified fields. -func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) { - extensionHeaders := header.IPv6ExtHdrSerializer{ - header.IPv6SerializableHopByHopExtHdr{ - &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, - }, - } - - extensionHeadersLength := extensionHeaders.Length() - payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize - buf := buffer.NewView(header.IPv6MinimumSize + payloadLength) - - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - HopLimit: header.MLDHopLimit, - TransportProtocol: header.ICMPv6ProtocolNumber, - SrcAddr: linkLocalIPv6Addr2, - DstAddr: header.IPv6AllNodesMulticastAddress, - ExtensionHeaders: extensionHeaders, - }) - - icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:]) - icmp.SetType(header.ICMPv6Type(mldType)) - mld := header.MLD(icmp.MessageBody()) - mld.SetMaximumResponseDelay(uint16(maxRespDelay)) - mld.SetMulticastAddress(groupAddress) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: linkLocalIPv6Addr2, - Dst: header.IPv6AllNodesMulticastAddress, - })) - - e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) -} - -// TestMGPDisabled tests that the multicast group protocol is not enabled by -// default. -func TestMGPDisabled(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - receivedQueryStat func(*stack.Stack) *tcpip.StatCounter - rxQuery func(*channel.Endpoint) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - rxQuery: func(e *channel.Endpoint) { - createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any) - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - rxQuery: func(e *channel.Endpoint) { - createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) - - // This NIC may join multicast groups when it is enabled but since MGP is - // disabled, no reports should be sent. - sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt) - } - - // Test joining a specific group explicitly and verify that no reports are - // sent. - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt) - } - - // Inject a general query message. This should only trigger a report to be - // sent if the MGP was enabled. - test.rxQuery(e) - if got := test.receivedQueryStat(s).Value(); got != 1 { - t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) - } - }) - } -} - -func TestMGPReceiveCounters(t *testing.T) { - tests := []struct { - name string - headerType uint8 - maxRespTime byte - groupAddress tcpip.Address - statCounter func(*stack.Stack) *tcpip.StatCounter - rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address) - }{ - { - name: "IGMP Membership Query", - headerType: igmpMembershipQuery, - maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec, - groupAddress: header.IPv4Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "IGMPv1 Membership Report", - headerType: igmpv1MembershipReport, - maxRespTime: 0, - groupAddress: header.IPv4AllSystems, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.V1MembershipReport - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "IGMPv2 Membership Report", - headerType: igmpv2MembershipReport, - maxRespTime: 0, - groupAddress: header.IPv4AllSystems, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.V2MembershipReport - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "IGMP Leave Group", - headerType: igmpLeaveGroup, - maxRespTime: 0, - groupAddress: header.IPv4AllRoutersGroup, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.LeaveGroup - }, - rxMGPkt: createAndInjectIGMPPacket, - }, - { - name: "MLD Query", - headerType: mldQuery, - maxRespTime: 0, - groupAddress: header.IPv6Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - rxMGPkt: createAndInjectMLDPacket, - }, - { - name: "MLD Report", - headerType: mldReport, - maxRespTime: 0, - groupAddress: header.IPv6Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport - }, - rxMGPkt: createAndInjectMLDPacket, - }, - { - name: "MLD Done", - headerType: mldDone, - maxRespTime: 0, - groupAddress: header.IPv6Any, - statCounter: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone - }, - rxMGPkt: createAndInjectMLDPacket, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) - - test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) - if got := test.statCounter(s).Value(); got != 1 { - t.Fatalf("got %s received = %d, want = 1", test.name, got) - } - }) - } -} - -// TestMGPJoinGroup tests that when explicitly joining a multicast group, the -// stack schedules and sends correct Membership Reports. -func TestMGPJoinGroup(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - maxUnsolicitedResponseDelay time.Duration - sentReportStat func(*stack.Stack) *tcpip.StatCounter - receivedQueryStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo) - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, _ = test.checkInitialGroups(t, e, s, clock) - } - - // Test joining a specific address explicitly and verify a Report is sent - // immediately. - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - reportCounter++ - sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - if t.Failed() { - t.FailNow() - } - - // Verify the second report is sent by the maximum unsolicited response - // interval. - p, ok := e.Read() - if ok { - t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) - } - clock.Advance(test.maxUnsolicitedResponseDelay) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -// TestMGPLeaveGroup tests that when leaving a previously joined multicast -// group the stack sends a leave/done message. -func TestMGPLeaveGroup(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo) - validateLeave func(*testing.T, channel.PacketInfo) - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.LeaveGroup - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1) - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - var leaveCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) - } - - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - reportCounter++ - if got := test.sentReportStat(s).Value(); got != reportCounter { - t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - if t.Failed() { - t.FailNow() - } - - // Leaving the group should trigger an leave/done message to be sent. - if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) - } - leaveCounter++ - if got := test.sentLeaveStat(s).Value(); got != leaveCounter { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a leave message to be sent") - } else { - test.validateLeave(t, p) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -// TestMGPQueryMessages tests that a report is sent in response to query -// messages. -func TestMGPQueryMessages(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - maxUnsolicitedResponseDelay time.Duration - sentReportStat func(*stack.Stack) *tcpip.StatCounter - receivedQueryStat func(*stack.Stack) *tcpip.StatCounter - rxQuery func(*channel.Endpoint, uint8, tcpip.Address) - validateReport func(*testing.T, channel.PacketInfo) - maxRespTimeToDuration func(uint8) time.Duration - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsReceived.MembershipQuery - }, - rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { - createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - maxRespTimeToDuration: header.DecisecondToDuration, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery - }, - rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { - createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - maxRespTimeToDuration: func(d uint8) time.Duration { - return time.Duration(d) * time.Millisecond - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - subTests := []struct { - name string - multicastAddr tcpip.Address - expectReport bool - }{ - { - name: "Unspecified", - multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))), - expectReport: true, - }, - { - name: "Specified", - multicastAddr: test.multicastAddr, - expectReport: true, - }, - { - name: "Specified other address", - multicastAddr: func() tcpip.Address { - addrBytes := []byte(test.multicastAddr) - addrBytes[len(addrBytes)-1]++ - return tcpip.Address(addrBytes) - }(), - expectReport: false, - }, - } - - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, _ = test.checkInitialGroups(t, e, s, clock) - } - - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - sentReportStat := test.sentReportStat(s) - for i := 0; i < maxUnsolicitedReports; i++ { - sentReportStat := test.sentReportStat(s) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatalf("expected %d-th report message to be sent", i) - } else { - test.validateReport(t, p) - } - clock.Advance(test.maxUnsolicitedResponseDelay) - } - if t.Failed() { - t.FailNow() - } - - // Should not send any more packets until a query. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - - // Receive a query message which should trigger a report to be sent at - // some time before the maximum response time if the report is - // targeted at the host. - const maxRespTime = 100 - test.rxQuery(e, maxRespTime, subTest.multicastAddr) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p.Pkt) - } - - if subTest.expectReport { - clock.Advance(test.maxRespTimeToDuration(maxRespTime)) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } - }) - } -} - -// TestMGPQueryMessages tests that no further reports or leave/done messages -// are sent after receiving a report. -func TestMGPReportMessages(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - rxReport func(*channel.Endpoint) - validateReport func(*testing.T, channel.PacketInfo) - maxRespTimeToDuration func(uint8) time.Duration - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.LeaveGroup - }, - rxReport: func(e *channel.Endpoint) { - createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) - }, - maxRespTimeToDuration: header.DecisecondToDuration, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - }, - rxReport: func(e *channel.Endpoint) { - createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1) - }, - validateReport: func(t *testing.T, p channel.PacketInfo) { - t.Helper() - - validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) - }, - maxRespTimeToDuration: func(d uint8) time.Duration { - return time.Duration(d) * time.Millisecond - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - var leaveCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) - } - - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - sentReportStat := test.sentReportStat(s) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p) - } - if t.Failed() { - t.FailNow() - } - - // Receiving a report for a group we joined should cancel any further - // reports. - test.rxReport(e) - clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); ok { - t.Errorf("sent unexpected packet = %#v", p) - } - if t.Failed() { - t.FailNow() - } - - // Leaving a group after getting a report should not send a leave/done - // message. - if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) - } - clock.Advance(time.Hour) - if got := test.sentLeaveStat(s).Value(); got != leaveCounter { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -func TestMGPWithNICLifecycle(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddrs []tcpip.Address - finalMulticastAddr tcpip.Address - maxUnsolicitedResponseDelay time.Duration - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) - validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) - getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address - checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2}, - finalMulticastAddr: ipv4MulticastAddr3, - maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.LeaveGroup - }, - validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr) - }, - getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { - t.Helper() - - ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) - if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber { - t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber) - } - addr := header.IGMP(ipv4.Payload()).GroupAddress() - s, ok := seen[addr] - if !ok { - t.Fatalf("unexpectedly got a packet for group %s", addr) - } - if s { - t.Fatalf("already saw packet for group %s", addr) - } - seen[addr] = true - return addr - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2}, - finalMulticastAddr: ipv6MulticastAddr3, - maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone - }, - validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateMLDPacket(t, p, addr, mldReport, 0, addr) - }, - validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { - t.Helper() - - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) - }, - getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { - t.Helper() - - ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) - - ipv6HeaderIter := header.MakeIPv6PayloadIterator( - header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), - buffer.View(ipv6.Payload()).ToVectorisedView(), - ) - - var transport header.IPv6RawPayloadHeader - for { - h, done, err := ipv6HeaderIter.Next() - if err != nil { - t.Fatalf("ipv6HeaderIter.Next(): %s", err) - } - if done { - t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done) - } - if t, ok := h.(header.IPv6RawPayloadHeader); ok { - transport = t - break - } - } - - if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber { - t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) - } - icmpv6 := header.ICMPv6(transport.Buf.ToView()) - if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { - t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) - } - addr := header.MLD(icmpv6.MessageBody()).MulticastAddress() - s, ok := seen[addr] - if !ok { - t.Fatalf("unexpectedly got a packet for group %s", addr) - } - if s { - t.Fatalf("already saw packet for group %s", addr) - } - seen[addr] = true - return addr - }, - checkInitialGroups: checkInitialIPv6Groups, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - - var reportCounter uint64 - var leaveCounter uint64 - if test.checkInitialGroups != nil { - reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) - } - - sentReportStat := test.sentReportStat(s) - for _, a := range test.multicastAddrs { - if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) - } - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatalf("expected a report message to be sent for %s", a) - } else { - test.validateReport(t, p, a) - } - } - if t.Failed() { - t.FailNow() - } - - // Leave messages should be sent for the joined groups when the NIC is - // disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - sentLeaveStat := test.sentLeaveStat(s) - leaveCounter += uint64(len(test.multicastAddrs)) - if got := sentLeaveStat.Value(); got != leaveCounter { - t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) - } - { - seen := make(map[tcpip.Address]bool) - for _, a := range test.multicastAddrs { - seen[a] = false - } - - for i := range test.multicastAddrs { - p, ok := e.Read() - if !ok { - t.Fatalf("expected (%d-th) leave message to be sent", i) - } - - test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p)) - } - } - if t.Failed() { - t.FailNow() - } - - // Reports should be sent for the joined groups when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("EnableNIC(%d): %s", nicID, err) - } - reportCounter += uint64(len(test.multicastAddrs)) - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - { - seen := make(map[tcpip.Address]bool) - for _, a := range test.multicastAddrs { - seen[a] = false - } - - for i := range test.multicastAddrs { - p, ok := e.Read() - if !ok { - t.Fatalf("expected (%d-th) report message to be sent", i) - } - - test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p)) - } - } - if t.Failed() { - t.FailNow() - } - - // Joining/leaving a group while disabled should not send any messages. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - leaveCounter += uint64(len(test.multicastAddrs)) - if got := sentLeaveStat.Value(); got != leaveCounter { - t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) - } - for i := range test.multicastAddrs { - if _, ok := e.Read(); !ok { - t.Fatalf("expected (%d-th) leave message to be sent", i) - } - } - for _, a := range test.multicastAddrs { - if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil { - t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err) - } - if got := sentLeaveStat.Value(); got != leaveCounter { - t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) - } - if p, ok := e.Read(); ok { - t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt) - } - } - if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err) - } - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); ok { - t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt) - } - - // A report should only be sent for the group we last joined after - // enabling the NIC since the original groups were all left. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("EnableNIC(%d): %s", nicID, err) - } - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p, test.finalMulticastAddr) - } - - clock.Advance(test.maxUnsolicitedResponseDelay) - reportCounter++ - if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) - } - if p, ok := e.Read(); !ok { - t.Fatal("expected a report message to be sent") - } else { - test.validateReport(t, p, test.finalMulticastAddr) - } - - // Should not send any more packets. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet = %#v", p) - } - }) - } -} - -// TestMGPDisabledOnLoopback tests that the multicast group protocol is not -// performed on loopback interfaces since they have no neighbours. -func TestMGPDisabledOnLoopback(t *testing.T) { - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - }{ - { - name: "IGMP", - protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().IGMP.PacketsSent.V2MembershipReport - }, - }, - { - name: "MLD", - protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr1, - sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { - return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) - - sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - - // Test joining a specific group explicitly and verify that no reports are - // sent. - if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) - } - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) - } - }) - } -} |