diff options
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 67 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp_test.go | 401 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_state_autogen.go | 113 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 3511 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/stats_test.go | 99 |
5 files changed, 113 insertions, 4078 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD deleted file mode 100644 index 2257f728e..000000000 --- a/pkg/tcpip/network/ipv4/BUILD +++ /dev/null @@ -1,67 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv4", - srcs = [ - "icmp.go", - "igmp.go", - "ipv4.go", - "stats.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/hash", - "//pkg/tcpip/network/internal/fragmentation", - "//pkg/tcpip/network/internal/ip", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv4_test", - size = "small", - srcs = [ - "igmp_test.go", - "ipv4_test.go", - ], - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/internal/testutil", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "stats_test", - size = "small", - srcs = ["stats_test.go"], - library = ":ipv4", - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - ], -) diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go deleted file mode 100644 index c6576fcbc..000000000 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ /dev/null @@ -1,401 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - nicID = 1 - defaultTTL = 1 - defaultPrefixLength = 24 -) - -var ( - stackAddr = testutil.MustParse4("10.0.0.1") - remoteAddr = testutil.MustParse4("10.0.0.2") - multicastAddr = testutil.MustParse4("224.0.0.3") -) - -// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet -// sent to the provided address with the passed fields set. Raises a t.Error if -// any field does not match. -func validateIgmpPacket(t *testing.T, p channel.PacketInfo, igmpType header.IGMPType, maxRespTime byte, srcAddr, dstAddr, groupAddress tcpip.Address) { - t.Helper() - - payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv4(t, payload, - checker.SrcAddr(srcAddr), - checker.DstAddr(dstAddr), - // TTL for an IGMP message must be 1 as per RFC 2236 section 2. - checker.TTL(1), - checker.IPv4RouterAlert(), - checker.IGMP( - checker.IGMPType(igmpType), - checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), - checker.IGMPGroupAddress(groupAddress), - ), - ) -} - -func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { - t.Helper() - - // Create an endpoint of queue size 1, since no more than 1 packets are ever - // queued in the tests in this file. - e := channel.New(1, 1280, linkAddr) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ - IGMP: ipv4.IGMPOptions{ - Enabled: igmpEnabled, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - return e, s, clock -} - -func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, ttl uint8, srcAddr, dstAddr, groupAddress tcpip.Address, hasRouterAlertOption bool) { - var options header.IPv4OptionsSerializer - if hasRouterAlertOption { - options = header.IPv4OptionsSerializer{ - &header.IPv4SerializableRouterAlertOption{}, - } - } - buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize) - - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: ttl, - Protocol: uint8(header.IGMPProtocolNumber), - SrcAddr: srcAddr, - DstAddr: dstAddr, - Options: options, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - igmp := header.IGMP(ip.Payload()) - igmp.SetType(igmpType) - igmp.SetMaxRespTime(maxRespTime) - igmp.SetGroupAddress(groupAddress) - igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - - e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) -} - -// TestIGMPV1Present tests the node's ability to fallback to V1 when a V1 -// router is detected. V1 present status is expected to be reset when the NIC -// cycles. -func TestIGMPV1Present(t *testing.T) { - e, s, clock := createStack(t, true) - protocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}, - } - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) - } - - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - // This NIC will send an IGMPv2 report immediately, before this test can get - // the IGMPv1 General Membership Query in. - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - - // Inject an IGMPv1 General Membership Query which is identical to a standard - // membership query except the Max Response Time is set to 0, which will tell - // the stack that this is a router using IGMPv1. Send it to the all systems - // group which is the only group this host belongs to. - createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, defaultTTL, remoteAddr, stackAddr, header.IPv4AllSystems, true /* hasRouterAlertOption */) - if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 { - t.Fatalf("got Membership Queries received = %d, want = 1", got) - } - - // Before advancing the clock, verify that this host has not sent a - // V1MembershipReport yet. - if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 { - t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got) - } - - // Verify the solicited Membership Report is sent. Now that this NIC has seen - // an IGMPv1 query, it should send an IGMPv1 Membership Report. - if p, ok := e.Read(); ok { - t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { - t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, header.IGMPv1MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - - // Cycling the interface should reset the V1 present flag. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - { - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) - } - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } -} - -func TestSendQueuedIGMPReports(t *testing.T) { - e, s, clock := createStack(t, true) - - // Joining a group without an assigned address should queue IGMP packets; none - // should be sent without an assigned address. - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) - } - reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport - if got := reportStat.Value(); got != 0 { - t.Errorf("got reportStat.Value() = %d, want = 0", got) - } - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("got unexpected packet = %#v", p) - } - - // The initial set of IGMP reports that were queued should be sent once an - // address is assigned. - protocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: stackAddr, - PrefixLen: defaultPrefixLength, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) - } - if got := reportStat.Value(); got != 1 { - t.Errorf("got reportStat.Value() = %d, want = 1", got) - } - if p, ok := e.Read(); !ok { - t.Error("expected to send an IGMP membership report") - } else { - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - if got := reportStat.Value(); got != 2 { - t.Errorf("got reportStat.Value() = %d, want = 2", got) - } - if p, ok := e.Read(); !ok { - t.Error("expected to send an IGMP membership report") - } else { - validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr) - } - if t.Failed() { - t.FailNow() - } - - // Should have no more packets to send after the initial set of unsolicited - // reports. - clock.Advance(time.Hour) - if p, ok := e.Read(); ok { - t.Fatalf("got unexpected packet = %#v", p) - } -} - -func TestIGMPPacketValidation(t *testing.T) { - tests := []struct { - name string - messageType header.IGMPType - stackAddresses []tcpip.AddressWithPrefix - srcAddr tcpip.Address - includeRouterAlertOption bool - ttl uint8 - expectValidIGMP bool - getMessageTypeStatValue func(tcpip.Stats) uint64 - }{ - { - name: "valid", - messageType: header.IGMPLeaveGroup, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, - }, - { - name: "bad ttl", - messageType: header.IGMPv1MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 2, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, - }, - { - name: "missing router alert ip option", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: false, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - { - name: "igmp leave group and src ip does not belong to nic subnet", - messageType: header.IGMPLeaveGroup, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: testutil.MustParse4("10.0.1.2"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, - }, - { - name: "igmp query and src ip does not belong to nic subnet", - messageType: header.IGMPMembershipQuery, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: testutil.MustParse4("10.0.1.2"), - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() }, - }, - { - name: "igmp report v1 and src ip does not belong to nic subnet", - messageType: header.IGMPv1MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: testutil.MustParse4("10.0.1.2"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, - }, - { - name: "igmp report v2 and src ip does not belong to nic subnet", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: testutil.MustParse4("10.0.1.2"), - ttl: 1, - expectValidIGMP: false, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - { - name: "src ip belongs to the subnet of the nic's second address", - messageType: header.IGMPv2MembershipReport, - includeRouterAlertOption: true, - stackAddresses: []tcpip.AddressWithPrefix{ - {Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24}, - {Address: stackAddr, PrefixLen: 24}, - }, - srcAddr: remoteAddr, - ttl: 1, - expectValidIGMP: true, - getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, true) - for _, address := range test.stackAddresses { - protocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: address, - } - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) - } - } - stats := s.Stats() - // Verify that every relevant stats is zero'd before we send a packet. - if got := test.getMessageTypeStatValue(s.Stats()); got != 0 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got) - } - if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != 0 { - t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = 0", got) - } - if got := stats.IP.PacketsDelivered.Value(); got != 0 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got) - } - createAndInjectIGMPPacket(e, test.messageType, 0, test.ttl, test.srcAddr, header.IPv4AllSystems, header.IPv4AllSystems, test.includeRouterAlertOption) - // We always expect the packet to pass IP validation. - if got := stats.IP.PacketsDelivered.Value(); got != 1 { - t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got) - } - // Even when the IGMP-specific validation checks fail, we expect the - // corresponding IGMP counter to be incremented. - if got := test.getMessageTypeStatValue(s.Stats()); got != 1 { - t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got) - } - var expectedInvalidCount uint64 - if !test.expectValidIGMP { - expectedInvalidCount = 1 - } - if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != expectedInvalidCount { - t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go new file mode 100644 index 000000000..19b672251 --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go @@ -0,0 +1,113 @@ +// automatically generated by stateify. + +package ipv4 + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *icmpv4DestinationUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationUnreachableSockError" +} + +func (i *icmpv4DestinationUnreachableSockError) StateFields() []string { + return []string{} +} + +func (i *icmpv4DestinationUnreachableSockError) beforeSave() {} + +// +checklocksignore +func (i *icmpv4DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() +} + +func (i *icmpv4DestinationUnreachableSockError) afterLoad() {} + +// +checklocksignore +func (i *icmpv4DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) { +} + +func (i *icmpv4DestinationHostUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationHostUnreachableSockError" +} + +func (i *icmpv4DestinationHostUnreachableSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + } +} + +func (i *icmpv4DestinationHostUnreachableSockError) beforeSave() {} + +// +checklocksignore +func (i *icmpv4DestinationHostUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationHostUnreachableSockError) afterLoad() {} + +// +checklocksignore +func (i *icmpv4DestinationHostUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationPortUnreachableSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4DestinationPortUnreachableSockError" +} + +func (i *icmpv4DestinationPortUnreachableSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + } +} + +func (i *icmpv4DestinationPortUnreachableSockError) beforeSave() {} + +// +checklocksignore +func (i *icmpv4DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (i *icmpv4DestinationPortUnreachableSockError) afterLoad() {} + +// +checklocksignore +func (i *icmpv4DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError) +} + +func (e *icmpv4FragmentationNeededSockError) StateTypeName() string { + return "pkg/tcpip/network/ipv4.icmpv4FragmentationNeededSockError" +} + +func (e *icmpv4FragmentationNeededSockError) StateFields() []string { + return []string{ + "icmpv4DestinationUnreachableSockError", + "mtu", + } +} + +func (e *icmpv4FragmentationNeededSockError) beforeSave() {} + +// +checklocksignore +func (e *icmpv4FragmentationNeededSockError) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.icmpv4DestinationUnreachableSockError) + stateSinkObject.Save(1, &e.mtu) +} + +func (e *icmpv4FragmentationNeededSockError) afterLoad() {} + +// +checklocksignore +func (e *icmpv4FragmentationNeededSockError) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.icmpv4DestinationUnreachableSockError) + stateSourceObject.Load(1, &e.mtu) +} + +func init() { + state.Register((*icmpv4DestinationUnreachableSockError)(nil)) + state.Register((*icmpv4DestinationHostUnreachableSockError)(nil)) + state.Register((*icmpv4DestinationPortUnreachableSockError)(nil)) + state.Register((*icmpv4FragmentationNeededSockError)(nil)) +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go deleted file mode 100644 index ef91245d7..000000000 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ /dev/null @@ -1,3511 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "bytes" - "encoding/hex" - "fmt" - "io/ioutil" - "math" - "net" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/raw" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - extraHeaderReserve = 50 - defaultMTU = 65536 -) - -func TestExcludeBroadcast(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) - if testing.Verbose() { - ep = sniffer.New(ep) - } - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }}) - - randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} - - var wq waiter.Queue - t.Run("WithoutPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Cannot connect using a broadcast address as the source. - { - err := ep.Connect(randomAddr) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{}) - } - } - - // However, we can bind to a broadcast address to listen. - if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { - t.Errorf("Bind failed: %v", err) - } - }) - - t.Run("WithPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Add a valid primary endpoint address, now we can connect. - protocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(), - } - if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) - } - if err := ep.Connect(randomAddr); err != nil { - t.Errorf("Connect failed: %v", err) - } - }) -} - -func TestForwarding(t *testing.T) { - const ( - incomingNICID = 1 - outgoingNICID = 2 - randomSequence = 123 - randomIdent = 42 - randomTimeOffset = 0x10203040 - ) - - incomingIPv4Addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), - PrefixLen: 8, - } - outgoingIPv4Addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), - PrefixLen: 8, - } - outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2") - remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2") - unreachableIPv4Addr := testutil.MustParse4("12.0.0.2") - multicastIPv4Addr := testutil.MustParse4("225.0.0.0") - linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0") - - tests := []struct { - name string - TTL uint8 - sourceAddr tcpip.Address - destAddr tcpip.Address - expectErrorICMP bool - ipFlags uint8 - mtu uint32 - payloadLength int - options header.IPv4Options - forwardedOptions header.IPv4Options - icmpType header.ICMPv4Type - icmpCode header.ICMPv4Code - expectPacketUnrouteableError bool - expectLinkLocalSourceError bool - expectLinkLocalDestError bool - expectPacketForwarded bool - expectedFragmentsForwarded []fragmentInfo - }{ - { - name: "TTL of zero", - TTL: 0, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - expectErrorICMP: true, - icmpType: header.ICMPv4TimeExceeded, - icmpCode: header.ICMPv4TTLExceeded, - mtu: ipv4.MaxTotalSize, - }, - { - name: "TTL of one", - TTL: 1, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - expectPacketForwarded: true, - mtu: ipv4.MaxTotalSize, - }, - { - name: "TTL of two", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - expectPacketForwarded: true, - mtu: ipv4.MaxTotalSize, - }, - { - name: "Max TTL", - TTL: math.MaxUint8, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - expectPacketForwarded: true, - mtu: ipv4.MaxTotalSize, - }, - { - name: "four EOL options", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - expectPacketForwarded: true, - mtu: ipv4.MaxTotalSize, - options: header.IPv4Options{0, 0, 0, 0}, - forwardedOptions: header.IPv4Options{0, 0, 0, 0}, - }, - { - name: "TS type 1 full", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - mtu: ipv4.MaxTotalSize, - options: header.IPv4Options{ - 68, 12, 13, 0xF1, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - expectErrorICMP: true, - icmpType: header.ICMPv4ParamProblem, - icmpCode: header.ICMPv4UnusedCode, - }, - { - name: "TS type 0", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - mtu: ipv4.MaxTotalSize, - options: header.IPv4Options{ - 68, 24, 21, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - }, - forwardedOptions: header.IPv4Options{ - 68, 24, 25, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - expectPacketForwarded: true, - }, - { - name: "end of options list", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - mtu: ipv4.MaxTotalSize, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 10, 3, 99, // EOL followed by junk - 1, 2, 3, 4, - }, - forwardedOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, // End of Options hides following bytes. - 0, 0, 0, // 7 bytes unknown option removed. - 0, 0, 0, 0, - }, - expectPacketForwarded: true, - }, - { - name: "Network unreachable", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: unreachableIPv4Addr, - expectErrorICMP: true, - mtu: ipv4.MaxTotalSize, - icmpType: header.ICMPv4DstUnreachable, - icmpCode: header.ICMPv4NetUnreachable, - expectPacketUnrouteableError: true, - }, - { - name: "Multicast destination", - TTL: 2, - destAddr: multicastIPv4Addr, - expectPacketUnrouteableError: true, - }, - { - name: "Link local destination", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: linkLocalIPv4Addr, - expectLinkLocalDestError: true, - }, - { - name: "Link local source", - TTL: 2, - sourceAddr: linkLocalIPv4Addr, - destAddr: remoteIPv4Addr2, - expectLinkLocalSourceError: true, - }, - { - name: "Fragmentation needed and DF set", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - ipFlags: header.IPv4FlagDontFragment, - // We've picked this MTU because it is: - // - // 1) Greater than the minimum MTU that IPv4 hosts are required to process - // (576 bytes). As per RFC 1812, Section 4.3.2.3: - // - // The ICMP datagram SHOULD contain as much of the original datagram as - // possible without the length of the ICMP datagram exceeding 576 bytes. - // - // Therefore, setting an MTU greater than 576 bytes ensures that we can fit a - // complete ICMP packet on the incoming endpoint (and make assertions about - // it). - // - // 2) Less than `ipv4.MaxTotalSize`, which lets us build an IPv4 packet whose - // size exceeds the MTU. - mtu: 1000, - payloadLength: 1004, - expectErrorICMP: true, - icmpType: header.ICMPv4DstUnreachable, - icmpCode: header.ICMPv4FragmentationNeeded, - }, - { - name: "Fragmentation needed and DF not set", - TTL: 2, - sourceAddr: remoteIPv4Addr1, - destAddr: remoteIPv4Addr2, - mtu: 1000, - payloadLength: 1004, - expectPacketForwarded: true, - // Combined, these fragments have length of 1012 octets, which is equal to - // the length of the payload (1004 octets), plus the length of the ICMP - // header (8 octets). - expectedFragmentsForwarded: []fragmentInfo{ - // The first fragment has a length of the greatest multiple of 8 which is - // less than or equal to to `mtu - header.IPv4MinimumSize`. - {offset: 0, payloadSize: uint16(976), more: true}, - // The next fragment holds the rest of the packet. - {offset: uint16(976), payloadSize: 36, more: false}, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - Clock: clock, - }) - - // Advance the clock by some unimportant amount to make - // it give a more recognisable signature than 00,00,00,00. - clock.Advance(time.Millisecond * randomTimeOffset) - - // We expect at most a single packet in response to our ICMP Echo Request. - incomingEndpoint := channel.New(1, test.mtu, "") - if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) - } - incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr} - if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err) - } - - expectedEmittedPacketCount := 1 - if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount { - expectedEmittedPacketCount = len(test.expectedFragmentsForwarded) - } - outgoingEndpoint := channel.New(expectedEmittedPacketCount, test.mtu, outgoingLinkAddr) - if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) - } - outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr} - if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: incomingIPv4Addr.Subnet(), - NIC: incomingNICID, - }, - { - Destination: outgoingIPv4Addr.Subnet(), - NIC: outgoingNICID, - }, - }) - - if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err) - } - - ipHeaderLength := header.IPv4MinimumSize + len(test.options) - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - icmpHeaderLength := header.ICMPv4MinimumSize - totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength - hdr := buffer.NewPrependable(totalLength) - hdr.Prepend(test.payloadLength) - icmpH := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) - icmpH.SetIdent(randomIdent) - icmpH.SetSequence(randomSequence) - icmpH.SetType(header.ICMPv4Echo) - icmpH.SetCode(header.ICMPv4UnusedCode) - icmpH.SetChecksum(0) - icmpH.SetChecksum(^header.Checksum(icmpH, 0)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLength), - Protocol: uint8(header.ICMPv4ProtocolNumber), - TTL: test.TTL, - SrcAddr: test.sourceAddr, - DstAddr: test.destAddr, - Flags: test.ipFlags, - }) - if len(test.options) != 0 { - ip.SetHeaderLength(uint8(ipHeaderLength)) - // Copy options manually. We do not use Encode for options so we can - // verify malformed options with handcrafted payloads. - if want, got := copy(ip.Options(), test.options), len(test.options); want != got { - t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) - } - } - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt) - - reply, ok := incomingEndpoint.Read() - - if test.expectErrorICMP { - if !ok { - t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) - } - - // We expect the ICMP packet to contain as much of the original packet as - // possible up to a limit of 576 bytes, split between payload, IP header, - // and ICMP header. - expectedICMPPayloadLength := func() int { - maxICMPPacketLength := header.IPv4MinimumProcessableDatagramSize - maxICMPPayloadLength := maxICMPPacketLength - icmpHeaderLength - ipHeaderLength - if len(hdr.View()) > maxICMPPayloadLength { - return maxICMPPayloadLength - } - return len(hdr.View()) - } - - checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(incomingIPv4Addr.Address), - checker.DstAddr(test.sourceAddr), - checker.TTL(ipv4.DefaultTTL), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(test.icmpType), - checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]), - ), - ) - } else if ok { - t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) - } - - if test.expectPacketForwarded { - if len(test.expectedFragmentsForwarded) != 0 { - var fragmentedPackets []*stack.PacketBuffer - for i := 0; i < len(test.expectedFragmentsForwarded); i++ { - reply, ok = outgoingEndpoint.Read() - if !ok { - t.Fatal("expected ICMP Echo fragment through outgoing NIC") - } - fragmentedPackets = append(fragmentedPackets, reply.Pkt) - } - - // The forwarded packet's TTL will have been decremented. - ipHeader := header.IPv4(requestPkt.NetworkHeader().View()) - ipHeader.SetTTL(ipHeader.TTL() - 1) - - // Forwarded packets have available header bytes equalling the sum of the - // maximum IP header size and the maximum size allocated for link layer - // headers. In this case, no size is allocated for link layer headers. - expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize - if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { - t.Error(err) - } - } else { - reply, ok = outgoingEndpoint.Read() - if !ok { - t.Fatal("expected ICMP Echo packet through outgoing NIC") - } - - checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(test.sourceAddr), - checker.DstAddr(test.destAddr), - checker.TTL(test.TTL-1), - checker.IPv4Options(test.forwardedOptions), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4Echo), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Payload(nil), - ), - ) - } - } else { - if reply, ok = outgoingEndpoint.Read(); ok { - t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) - } - } - boolToInt := func(val bool) uint64 { - if val { - return 1 - } - return 0 - } - - if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { - t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { - t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want { - t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want { - t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { - t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { - t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) - } - - if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpCode == header.ICMPv4FragmentationNeeded); got != want { - t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want) - } - }) - } -} - -// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and -// checks the response. -func TestIPv4Sanity(t *testing.T) { - const ( - ttl = 255 - nicID = 1 - randomSequence = 123 - randomIdent = 42 - // In some cases Linux sets the error pointer to the start of the option - // (offset 0) instead of the actual wrong value, which is the length byte - // (offset 1). For compatibility we must do the same. Use this constant - // to indicate where this happens. - pointerOffsetForInvalidLength = 0 - randomTimeOffset = 0x10203040 - ) - var ( - ipv4Addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), - PrefixLen: 24, - } - remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) - ) - - tests := []struct { - name string - headerLength uint8 // value of 0 means "use correct size" - badHeaderChecksum bool - maxTotalLength uint16 - transportProtocol uint8 - TTL uint8 - options header.IPv4Options - replyOptions header.IPv4Options // reply should look like this - shouldFail bool - expectErrorICMP bool - ICMPType header.ICMPv4Type - ICMPCode header.ICMPv4Code - paramProblemPointer uint8 - }{ - { - name: "valid no options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - }, - { - name: "bad header checksum", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - badHeaderChecksum: true, - shouldFail: true, - }, - // The TTL tests check that we are not rejecting an incoming packet - // with a zero or one TTL, which has been a point of confusion in the - // past as RFC 791 says: "If this field contains the value zero, then the - // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies - // for the case of the destination host, stating as follows. - // - // A host MUST NOT send a datagram with a Time-to-Live (TTL) - // value of zero. - // - // A host MUST NOT discard a datagram just because it was - // received with TTL less than 2. - { - name: "zero TTL", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: 0, - }, - { - name: "one TTL", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: 1, - }, - { - name: "End options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{0, 0, 0, 0}, - replyOptions: header.IPv4Options{0, 0, 0, 0}, - }, - { - name: "NOP options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{1, 1, 1, 1}, - replyOptions: header.IPv4Options{1, 1, 1, 1}, - }, - { - name: "NOP and End options", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{1, 1, 0, 0}, - replyOptions: header.IPv4Options{1, 1, 0, 0}, - }, - { - name: "bad header length", - headerLength: header.IPv4MinimumSize - 1, - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (0)", - maxTotalLength: 0, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (ip - 1)", - maxTotalLength: uint16(header.IPv4MinimumSize - 1), - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad total length (ip + icmp - 1)", - maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - shouldFail: true, - }, - { - name: "bad protocol", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: 99, - TTL: ttl, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4DstUnreachable, - ICMPCode: header.ICMPv4ProtoUnreachable, - }, - { - name: "timestamp option overflow", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - replyOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - }, - { - name: "timestamp option overflow full", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0xF1, - // ^ Counter full (15/0xF) - 192, 168, 1, 12, - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 3, - replyOptions: header.IPv4Options{}, - }, - { - name: "unknown option", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{10, 4, 9, 0}, - // ^^ - // The unknown option should be stripped out of the reply. - replyOptions: header.IPv4Options{}, - }, - { - name: "bad option - no length", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 1, 1, 1, 68, - // ^-start of timestamp.. but no length.. - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 3, - }, - { - name: "bad option - length 0", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 0, 9, 0, - // ^ - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "bad option - length 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 1, 9, 0, - // ^ - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "bad option - length big", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 9, 9, 0, - // ^ - // There are only 8 bytes allocated to options so 9 bytes of timestamp - // space is not possible. (Second byte) - 1, 2, 3, 4, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - // This tests for some linux compatible behaviour. - // The ICMP pointer returned is 22 for Linux but the - // error is actually in spot 21. - name: "bad option - length bad", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - // Timestamps are in multiples of 4 or 8 but never 7. - // The option space should be padded out. - options: header.IPv4Options{ - 68, 7, 5, 0, - // ^ ^ Linux points here which is wrong. - // | Not a multiple of 4 - 1, 2, 3, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - { - name: "multiple type 0 with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 24, 21, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 24, 25, 0x00, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // The timestamp area is full so add to the overflow count. - name: "multiple type 1 timestamps", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 20, 21, 0x11, - // ^ - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - }, - // Overflow count is the top nibble of the 4th byte. - replyOptions: header.IPv4Options{ - 68, 20, 21, 0x21, - // ^ - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - }, - }, - { - name: "multiple type 1 timestamps with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 28, 21, 0x01, - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 28, 29, 0x01, - 192, 168, 1, 12, - 1, 2, 3, 4, - 192, 168, 1, 13, - 5, 6, 7, 8, - 192, 168, 1, 58, // New IP Address. - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // Timestamp pointer uses one based counting so 0 is invalid. - name: "timestamp pointer invalid", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, 0, 0x00, - // ^ 0 instead of 5 or more. - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, - }, - { - // Timestamp pointer cannot be less than 5. It must point past the header - // which is 4 bytes. (1 based counting) - name: "timestamp pointer too small by 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, header.IPv4OptionTimestampHdrLength, 0x00, - // ^ header is 4 bytes, so 4 should fail. - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - { - name: "valid timestamp pointer", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00, - // ^ header is 4 bytes, so 5 should succeed. - 0, 0, 0, 0, - }, - replyOptions: header.IPv4Options{ - 68, 8, 9, 0x00, - 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock - }, - }, - { - // Needs 8 bytes for a type 1 timestamp but there are only 4 free. - name: "bad timer element alignment", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 20, 17, 0x01, - // ^^ ^^ 20 byte area, next free spot at 17. - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 0, 0, 0, - 0, 0, 0, 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, - }, - // End of option list with illegal option after it, which should be ignored. - { - name: "end of options list", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 12, 13, 0x11, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, 10, 3, 99, // EOL followed by junk - }, - replyOptions: header.IPv4Options{ - 68, 12, 13, 0x21, - 192, 168, 1, 12, - 1, 2, 3, 4, - 0, // End of Options hides following bytes. - 0, 0, 0, // 3 bytes unknown option removed. - }, - }, - { - // Timestamp with a size much too small. - name: "timestamp truncated", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 68, 1, 0, 0, - // ^ Smallest possible is 8. Linux points at the 68. - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, - }, - { - name: "single record route with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 4, // 3 byte header - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - name: "multiple record route with room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 23, 20, // 3 byte header - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 23, 24, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - name: "single record route with no room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Unlike timestamp, this should just succeed. - name: "multiple record route with no room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 23, 24, // 3 byte header - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 23, 24, - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Pointer uses one based counting so 0 is invalid. - name: "record route pointer zero", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, 0, // 3 byte header - 0, 0, 0, 0, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - // Pointer must be 4 or more as it must point past the 3 byte header - // using 1 based counting. 3 should fail. - name: "record route pointer too small by 1", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header - 0, 0, 0, 0, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - // Pointer must be 4 or more as it must point past the 3 byte header - // using 1 based counting. Check 4 passes. (Duplicates "single - // record route with room") - name: "valid record route pointer", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header - 0, 0, 0, 0, - 0, - }, - replyOptions: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 192, 168, 1, 58, // New IP Address. - 0, // padding to multiple of 4 bytes. - }, - }, - { - // Confirm Linux bug for bug compatibility. - // Linux returns slot 22 but the error is in slot 21. - name: "multiple record route with not enough room", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 8, 8, // 3 byte header - // ^ ^ Linux points here. We must too. - // | Not enough room. 1 byte free, need 4. - 1, 2, 3, 4, - 0, - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, - }, - { - name: "duplicate record route", - maxTotalLength: ipv4.MaxTotalSize, - transportProtocol: uint8(header.ICMPv4ProtocolNumber), - TTL: ttl, - options: header.IPv4Options{ - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 7, 7, 8, // 3 byte header - 1, 2, 3, 4, - 0, 0, // pad - }, - shouldFail: true, - expectErrorICMP: true, - ICMPType: header.ICMPv4ParamProblem, - ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 7, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - Clock: clock, - }) - // Advance the clock by some unimportant amount to make - // it give a more recognisable signature than 00,00,00,00. - clock.Advance(time.Millisecond * randomTimeOffset) - - // We expect at most a single packet in response to our ICMP Echo Request. - e := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err) - } - - // Default routes for IPv4 so ICMP can find a route to the remote - // node when attempting to send the ICMP Echo Reply. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - if len(test.options)%4 != 0 { - t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options)) - } - ipHeaderLength := header.IPv4MinimumSize + len(test.options) - if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) - } - totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(totalLen)) - icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - - // Specify ident/seq to make sure we get the same in the response. - icmpH.SetIdent(randomIdent) - icmpH.SetSequence(randomSequence) - icmpH.SetType(header.ICMPv4Echo) - icmpH.SetCode(header.ICMPv4UnusedCode) - icmpH.SetChecksum(0) - icmpH.SetChecksum(^header.Checksum(icmpH, 0)) - ip := header.IPv4(hdr.Prepend(ipHeaderLength)) - if test.maxTotalLength < totalLen { - totalLen = test.maxTotalLength - } - ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, - Protocol: test.transportProtocol, - TTL: test.TTL, - SrcAddr: remoteIPv4Addr, - DstAddr: ipv4Addr.Address, - }) - if test.headerLength != 0 { - ip.SetHeaderLength(test.headerLength) - } else { - // Set the calculated header length, since we may manually add options. - ip.SetHeaderLength(uint8(ipHeaderLength)) - } - if len(test.options) != 0 { - // Copy options manually. We do not use Encode for options so we can - // verify malformed options with handcrafted payloads. - if want, got := copy(ip.Options(), test.options), len(test.options); want != got { - t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) - } - } - ip.SetChecksum(0) - ipHeaderChecksum := ip.CalculateChecksum() - if test.badHeaderChecksum { - ipHeaderChecksum += 42 - } - ip.SetChecksum(^ipHeaderChecksum) - requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) - reply, ok := e.Read() - if !ok { - if test.shouldFail { - if test.expectErrorICMP { - t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) - } - return // Expected silent failure. - } - t.Fatal("expected ICMP echo reply missing") - } - - // We didn't expect a packet. Register our surprise but carry on to - // provide more information about what we got. - if test.shouldFail && !test.expectErrorICMP { - t.Error("unexpected packet response") - } - - // Check the route that brought the packet to us. - if reply.Route.LocalAddress != ipv4Addr.Address { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) - } - if reply.Route.RemoteAddress != remoteIPv4Addr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) - } - - // Make sure it's all in one buffer for checker. - replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) - - // At this stage we only know it's probably an IP+ICMP header so verify - // that much. - checker.IPv4(t, replyIPHeader, - checker.SrcAddr(ipv4Addr.Address), - checker.DstAddr(remoteIPv4Addr), - checker.ICMPv4( - checker.ICMPv4Checksum(), - ), - ) - - // Don't proceed any further if the checker found problems. - if t.Failed() { - t.FailNow() - } - - // OK it's ICMP. We can safely look at the type now. - replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) - switch replyICMPHeader.Type() { - case header.ICMPv4ParamProblem: - if !test.shouldFail { - t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) - } - if !test.expectErrorICMP { - t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) - } - checker.IPv4(t, replyIPHeader, - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(test.ICMPType), - checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Pointer(test.paramProblemPointer), - checker.ICMPv4Payload(hdr.View()), - ), - ) - return - case header.ICMPv4DstUnreachable: - if !test.shouldFail { - t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", - header.ICMPv4DstUnreachable, replyICMPHeader.Code()) - } - if !test.expectErrorICMP { - t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", - header.ICMPv4DstUnreachable, replyICMPHeader.Code()) - } - checker.IPv4(t, replyIPHeader, - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(test.ICMPType), - checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Payload(hdr.View()), - ), - ) - return - case header.ICMPv4EchoReply: - if test.shouldFail { - if !test.expectErrorICMP { - t.Error("got Echo Reply packet, want no response") - } else { - t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) - } - } - // If the IP options change size then the packet will change size, so - // some IP header fields will need to be adjusted for the checks. - sizeChange := len(test.replyOptions) - len(test.options) - - checker.IPv4(t, replyIPHeader, - checker.IPv4HeaderLength(ipHeaderLength+sizeChange), - checker.IPv4Options(test.replyOptions), - checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Seq(randomSequence), - checker.ICMPv4Ident(randomIdent), - ), - ) - default: - t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", - replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) - } - }) - } -} - -// compareFragments compares the contents of a set of fragmented packets against -// the contents of a source packet. -// -// If withIPHeader is set to true, we will validate the fragmented packets' IP -// headers against the source packet's IP header. If set to false, we validate -// the fragmented packets' IP headers against each other. -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber, withIPHeader bool, expectedAvailableHeaderBytes int) error { - // Make a complete array of the sourcePacket packet. - var source header.IPv4 - vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) - - // If the packet to be fragmented contains an IPv4 header, use that header for - // validating fragment headers. Else, use the header of the first fragment. - if withIPHeader { - source = header.IPv4(vv.ToView()) - } else { - source = header.IPv4(packets[0].NetworkHeader().View()) - source = append(source, vv.ToView()...) - } - - // Make a copy of the IP header, which will be modified in some fields to make - // an expected header. - sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) - sourceCopy.SetChecksum(0) - sourceCopy.SetFlagsFragmentOffset(0, 0) - sourceCopy.SetTotalLength(0) - // Build up an array of the bytes sent. - var reassembledPayload buffer.VectorisedView - for i, packet := range packets { - // Confirm that the packet is valid. - allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) - fragmentIPHeader := header.IPv4(allBytes.ToView()) - if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { - return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) - } - if got := len(fragmentIPHeader); got > int(mtu) { - return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) - } - if got := fragmentIPHeader.TransportProtocol(); got != proto { - return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) - } - if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { - return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) - } - if got := packet.AvailableHeaderBytes(); got != expectedAvailableHeaderBytes { - return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, expectedAvailableHeaderBytes) - } - if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { - return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) - } - if wantFragments[i].more { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset) - } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) - } - reassembledPayload.AppendView(packet.TransportHeader().View()) - reassembledPayload.AppendView(packet.Data().AsRange().ToOwnedView()) - // Clear out the checksum and length from the ip because we can't compare - // it. - sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) - sourceCopy.SetChecksum(0) - sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - - // If we are validating against the original IP header, we should exclude the - // ID field, which will only be set fo fragmented packets. - if withIPHeader { - fragmentIPHeader.SetID(0) - fragmentIPHeader.SetChecksum(0) - fragmentIPHeader.SetChecksum(^fragmentIPHeader.CalculateChecksum()) - } - if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { - return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) - } - } - - expected := buffer.View(source[source.HeaderLength():]) - if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { - return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) - } - - return nil -} - -type fragmentInfo struct { - offset uint16 - more bool - payloadSize uint16 -} - -var fragmentationTests = []struct { - description string - mtu uint32 - transportHeaderLength int - payloadSize int - wantFragments []fragmentInfo -}{ - { - description: "No fragmentation", - mtu: 1280, - transportHeaderLength: 0, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1000, more: false}, - }, - }, - { - description: "Fragmented", - mtu: 1280, - transportHeaderLength: 0, - payloadSize: 2000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 744, more: false}, - }, - }, - { - description: "Fragmented with the minimum mtu", - mtu: header.IPv4MinimumMTU, - transportHeaderLength: 0, - payloadSize: 100, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 48, more: true}, - {offset: 48, payloadSize: 48, more: true}, - {offset: 96, payloadSize: 4, more: false}, - }, - }, - { - description: "Fragmented with mtu not a multiple of 8", - mtu: header.IPv4MinimumMTU + 1, - transportHeaderLength: 0, - payloadSize: 100, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 48, more: true}, - {offset: 48, payloadSize: 48, more: true}, - {offset: 96, payloadSize: 4, more: false}, - }, - }, - { - description: "No fragmentation with big header", - mtu: 2000, - transportHeaderLength: 100, - payloadSize: 1000, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1100, more: false}, - }, - }, - { - description: "Fragmented with big header", - mtu: 1280, - transportHeaderLength: 100, - payloadSize: 1200, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 44, more: false}, - }, - }, - { - description: "Fragmented with MTU smaller than header", - mtu: 300, - transportHeaderLength: 1000, - payloadSize: 500, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 280, more: true}, - {offset: 280, payloadSize: 280, more: true}, - {offset: 560, payloadSize: 280, more: true}, - {offset: 840, payloadSize: 280, more: true}, - {offset: 1120, payloadSize: 280, more: true}, - {offset: 1400, payloadSize: 100, more: false}, - }, - }, -} - -func TestFragmentationWritePacket(t *testing.T) { - const ttl = 42 - - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - source := pkt.Clone() - err := r.WritePacket(stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if err != nil { - t.Fatalf("r.WritePacket(...): %s", err) - } - if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) - } - if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil { - t.Error(err) - } - }) - } -} - -func TestFragmentationWritePackets(t *testing.T) { - const ttl = 42 - writePacketsTests := []struct { - description string - insertBefore int - insertAfter int - }{ - { - description: "Single packet", - insertBefore: 0, - insertAfter: 0, - }, - { - description: "With packet before", - insertBefore: 1, - insertAfter: 0, - }, - { - description: "With packet after", - insertBefore: 0, - insertAfter: 1, - }, - { - description: "With packet before and after", - insertBefore: 1, - insertAfter: 1, - }, - } - tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) - - for _, test := range writePacketsTests { - t.Run(test.description, func(t *testing.T) { - for _, ft := range fragmentationTests { - t.Run(ft.description, func(t *testing.T) { - var pkts stack.PacketBufferList - for i := 0; i < test.insertBefore; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - pkts.PushBack(pkt.Clone()) - for i := 0; i < test.insertAfter; i++ { - pkts.PushBack(tinyPacket.Clone()) - } - - ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) - r := buildRoute(t, ep) - - wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }) - if err != nil { - t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err) - } - if n != wantTotalPackets { - t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets) - } - if got := len(ep.WrittenPackets); got != wantTotalPackets { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { - t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) - } - - if wantTotalPackets == 0 { - return - } - - fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] - if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil { - t.Error(err) - } - }) - } - }) - } -} - -// TestFragmentationErrors checks that errors are returned from WritePacket -// correctly. -func TestFragmentationErrors(t *testing.T) { - const ttl = 42 - - tests := []struct { - description string - mtu uint32 - transportHeaderLength int - payloadSize int - allowPackets int - outgoingErrors int - mockError tcpip.Error - wantError tcpip.Error - }{ - { - description: "No frag", - mtu: 2000, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 0, - outgoingErrors: 1, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag", - mtu: 500, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 0, - outgoingErrors: 3, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on second frag", - mtu: 500, - payloadSize: 1000, - transportHeaderLength: 0, - allowPackets: 1, - outgoingErrors: 2, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error on first frag MTU smaller than header", - mtu: 500, - transportHeaderLength: 1000, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 4, - mockError: &tcpip.ErrAborted{}, - wantError: &tcpip.ErrAborted{}, - }, - { - description: "Error when MTU is smaller than IPv4 minimum MTU", - mtu: header.IPv4MinimumMTU - 1, - transportHeaderLength: 0, - payloadSize: 500, - allowPackets: 0, - outgoingErrors: 1, - mockError: nil, - wantError: &tcpip.ErrInvalidEndpointState{}, - }, - } - - for _, ft := range tests { - t.Run(ft.description, func(t *testing.T) { - pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) - r := buildRoute(t, ep) - err := r.WritePacket(stack.NetworkHeaderParams{ - Protocol: tcp.ProtocolNumber, - TTL: ttl, - TOS: stack.DefaultTOS, - }, pkt) - if diff := cmp.Diff(ft.wantError, err); diff != "" { - t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff) - } - if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { - t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) - } - if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) - } - }) - } -} - -func TestInvalidFragments(t *testing.T) { - const ( - nicID = 1 - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = tcpip.Address("\x0a\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x02") - tos = 0 - ident = 1 - ttl = 48 - protocol = 6 - ) - - payloadGen := func(payloadLen int) []byte { - payload := make([]byte, payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = 0x30 - } - return payload - } - - type fragmentData struct { - ipv4fields header.IPv4Fields - // 0 means insert the correct IHL. Non 0 means override the correct IHL. - overrideIHL int // For 0 use 1 as it is an int and will be divided by 4. - payload []byte - autoChecksum bool // If true, the Checksum field will be overwritten. - } - - tests := []struct { - name string - fragments []fragmentData - wantMalformedIPPackets uint64 - wantMalformedFragments uint64 - }{ - { - name: "IHL and TotalLength zero, FragmentOffset non-zero", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: 0, - ID: ident, - Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments, - FragmentOffset: 59776, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - overrideIHL: 1, // See note above. - payload: payloadGen(12), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 0, - }, - { - name: "IHL and TotalLength zero, FragmentOffset zero", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: 0, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - overrideIHL: 1, // See note above. - payload: payloadGen(12), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 0, - }, - { - // Payload 17 octets and Fragment offset 65520 - // Leading to the fragment end to be past 65536. - name: "fragment ends past 65536", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 17, - ID: ident, - Flags: 0, - FragmentOffset: 65520, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(17), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - { - // Payload 16 octets and fragment offset 65520 - // Leading to the fragment end to be exactly 65536. - name: "fragment ends exactly at 65536", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: 0, - FragmentOffset: 65520, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(16), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 0, - wantMalformedFragments: 0, - }, - { - name: "IHL less than IPv4 minimum size", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 28, - ID: ident, - Flags: 0, - FragmentOffset: 1944, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize - 12, - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize - 12, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize - 12, - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 2, - wantMalformedFragments: 0, - }, - { - name: "fragment with short TotalLength and extra payload", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 28, - ID: ident, - Flags: 0, - FragmentOffset: 28816, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize + 4, - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 4, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(28), - overrideIHL: header.IPv4MinimumSize + 4, - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - { - name: "multiple fragments with More Fragments flag set to false", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: 0, - FragmentOffset: 128, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: 0, - FragmentOffset: 8, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: payloadGen(8), - autoChecksum: true, - }, - }, - wantMalformedIPPackets: 1, - wantMalformedFragments: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - }, - }) - e := channel.New(0, 1500, linkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: addr2.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) - } - - for _, f := range test.fragments { - pktSize := header.IPv4MinimumSize + len(f.payload) - hdr := buffer.NewPrependable(pktSize) - - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&f.ipv4fields) - if want, got := len(f.payload), copy(ip[header.IPv4MinimumSize:], f.payload); want != got { - t.Fatalf("copied %d bytes, expected %d bytes.", got, want) - } - // Encode sets this up correctly. If we want a different value for - // testing then we need to overwrite the good value. - if f.overrideIHL != 0 { - ip.SetHeaderLength(uint8(f.overrideIHL)) - // If we are asked to add options (type not specified) then pad - // with 0 (EOL). RFC 791 page 23 says "The padding is zero". - for i := header.IPv4MinimumSize; i < f.overrideIHL; i++ { - ip[i] = byte(header.IPv4OptionListEndType) - } - } - - if f.autoChecksum { - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - } - - vv := hdr.View().ToVectorisedView() - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { - t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) - } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { - t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) - } - }) - } -} - -func TestFragmentReassemblyTimeout(t *testing.T) { - const ( - nicID = 1 - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - addr1 = tcpip.Address("\x0a\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x02") - tos = 0 - ident = 1 - ttl = 48 - protocol = 99 - data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" - ) - - type fragmentData struct { - ipv4fields header.IPv4Fields - payload []byte - } - - tests := []struct { - name string - fragments []fragmentData - expectICMP bool - }{ - { - name: "first fragment only", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "two first fragments", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 16, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:16], - }, - }, - expectICMP: true, - }, - { - name: "second fragment only", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 8, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: false, - }, - { - name: "two fragments with a gap", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:8], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 16, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - }, - expectICMP: true, - }, - { - name: "two fragments with a gap in reverse order", - fragments: []fragmentData{ - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), - ID: ident, - Flags: 0, - FragmentOffset: 16, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[16:], - }, - { - ipv4fields: header.IPv4Fields{ - TOS: tos, - TotalLength: header.IPv4MinimumSize + 8, - ID: ident, - Flags: header.IPv4FlagMoreFragments, - FragmentOffset: 0, - TTL: ttl, - Protocol: protocol, - SrcAddr: addr1, - DstAddr: addr2, - }, - payload: []byte(data)[:8], - }, - }, - expectICMP: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - }, - Clock: clock, - }) - e := channel.New(1, 1500, linkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: addr2.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }}) - - var firstFragmentSent buffer.View - for _, f := range test.fragments { - pktSize := header.IPv4MinimumSize - hdr := buffer.NewPrependable(pktSize) - - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&f.ipv4fields) - - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(f.payload) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - - if firstFragmentSent == nil && ip.FragmentOffset() == 0 { - firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) - } - - e.InjectInbound(header.IPv4ProtocolNumber, pkt) - } - - clock.Advance(ipv4.ReassembleTimeout) - - reply, ok := e.Read() - if !test.expectICMP { - if ok { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - return - } - if !ok { - t.Fatal("expected ICMP error message missing") - } - if firstFragmentSent == nil { - t.Fatalf("unexpected ICMP error message received: %#v", reply) - } - - checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), - checker.SrcAddr(addr2), - checker.DstAddr(addr1), - checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), - checker.IPv4HeaderLength(header.IPv4MinimumSize), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4TimeExceeded), - checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), - checker.ICMPv4Checksum(), - checker.ICMPv4Payload(firstFragmentSent), - ), - ) - }) - } -} - -// TestReceiveFragments feeds fragments in through the incoming packet path to -// test reassembly -func TestReceiveFragments(t *testing.T) { - const ( - nicID = 1 - - addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1 - addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2 - addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3 - ) - - // Build and return a UDP header containing payload. - udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View { - payload := buffer.NewView(payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) * multiplier - } - - udpLength := header.UDPMinimumSize + len(payload) - - hdr := buffer.NewPrependable(udpLength) - u := header.UDP(hdr.Prepend(udpLength)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: uint16(udpLength), - }) - copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) - sum = header.Checksum(payload, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - return hdr.View() - } - - // UDP header plus a payload of 0..256 - ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2) - udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:] - ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2) - udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:] - // UDP header plus a payload of 0..256 in increments of 2. - ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2) - udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:] - // UDP header plus a payload of 0..256 in increments of 3. - // Used to test cases where the fragment blocks are not a multiple of - // the fragment block size of 8 (RFC 791 section 3.1 page 14). - ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) - udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] - // Used to test the max reassembled IPv4 payload length. - ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) - udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] - - type fragmentData struct { - srcAddr tcpip.Address - dstAddr tcpip.Address - id uint16 - flags uint8 - fragmentOffset uint16 - payload buffer.View - } - - tests := []struct { - name string - fragments []fragmentData - expectedPayloads [][]byte - }{ - { - name: "No fragmentation", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "No fragmentation with size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2, - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "More fragments without payload", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: nil, - }, - { - name: "Non-zero fragment offset without payload", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 8, - payload: ipv4Payload1Addr1ToAddr2, - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments out of order", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, - }, - { - name: "Two fragments with last fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload3Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, - }, - { - name: "Two fragments with first fragment size not a multiple of fragment block size", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload3Addr1ToAddr2[:63], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 63, - payload: ipv4Payload3Addr1ToAddr2[63:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Second fragment has MoreFlags set", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments with different IDs", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two interleaved fragmented packets", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload2Addr1ToAddr2[:64], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 2, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload2Addr1ToAddr2[64:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, - }, - { - name: "Two interleaved fragmented packets from different sources but with same ID", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - { - srcAddr: addr3, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr3ToAddr2[:32], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 64, - payload: ipv4Payload1Addr1ToAddr2[64:], - }, - { - srcAddr: addr3, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 32, - payload: ipv4Payload1Addr3ToAddr2[32:], - }, - }, - expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, - }, - { - name: "Fragment without followup", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload1Addr1ToAddr2[:64], - }, - }, - expectedPayloads: nil, - }, - { - name: "Two fragments reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload4Addr1ToAddr2[:65512], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: 0, - fragmentOffset: 65512, - payload: ipv4Payload4Addr1ToAddr2[65512:], - }, - }, - expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, - }, - { - name: "Two fragments with MF flag reassembled into a maximum UDP packet", - fragments: []fragmentData{ - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 0, - payload: ipv4Payload4Addr1ToAddr2[:65512], - }, - { - srcAddr: addr1, - dstAddr: addr2, - id: 1, - flags: header.IPv4FlagMoreFragments, - fragmentOffset: 65512, - payload: ipv4Payload4Addr1ToAddr2[65512:], - }, - }, - expectedPayloads: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Setup a stack and endpoint. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - RawFactory: raw.EndpointFactory{}, - }) - e := channel.New(0, 1280, "\xf0\x00") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: addr2.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) - } - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - // Bring up a raw endpoint so we can examine network headers. - epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */) - if err != nil { - t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) - } - defer epRaw.Close() - - // Prepare and send the fragments. - for _, frag := range test.fragments { - hdr := buffer.NewPrependable(header.IPv4MinimumSize) - - // Serialize IPv4 fixed header. - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)), - ID: frag.id, - Flags: frag.flags, - FragmentOffset: frag.fragmentOffset, - TTL: 64, - Protocol: uint8(header.UDPProtocolNumber), - SrcAddr: frag.srcAddr, - DstAddr: frag.dstAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - vv := hdr.View().ToVectorisedView() - vv.AppendView(frag.payload) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } - - if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { - t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) - } - - for i, expectedPayload := range test.expectedPayloads { - // Check UDP payload delivered by UDP endpoint. - var buf bytes.Buffer - result, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) ep.Read: %s", i, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(expectedPayload), - Total: len(expectedPayload), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff) - } - if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" { - t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff) - } - - // Check IPv4 header in packet delivered by raw endpoint. - buf.Reset() - result, err = epRaw.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("(i=%d) epRaw.Read: %s", i, err) - } - // Reassambly does not take care of checksum. Here we write our own - // check routine instead of using checker.IPv4. - ip := header.IPv4(buf.Bytes()) - for _, check := range []checker.NetworkChecker{ - checker.FragmentFlags(0), - checker.FragmentOffset(0), - checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))), - } { - check(t, []header.Network{ip}) - } - } - - res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) - } - }) - } -} - -func TestWriteStats(t *testing.T) { - const nPackets = 3 - - tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectOutputDropped int - expectPostroutingDropped int - expectWritten int - }{ - { - name: "Accept all", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectOutputDropped: 0, - expectPostroutingDropped: 0, - expectWritten: nPackets, - }, { - name: "Accept all with error", - // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectOutputDropped: 0, - expectPostroutingDropped: 0, - expectWritten: nPackets - 1, - }, { - name: "Drop all with Output chain", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule. - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectOutputDropped: nPackets, - expectPostroutingDropped: 0, - expectWritten: nPackets, - }, { - name: "Drop all with Postrouting chain", - setup: func(t *testing.T, stk *stack.Stack) { - ipt := stk.IPTables() - filter := ipt.GetTable(stack.NATID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Postrouting] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectOutputDropped: 0, - expectPostroutingDropped: nPackets, - expectWritten: nPackets, - }, { - name: "Drop some with Output chain", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Output DROP rule that matches only 1 - // of the 3 packets. - ipt := stk.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - // We'll match and DROP the last packet. - ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} - // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectOutputDropped: 1, - expectPostroutingDropped: 0, - expectWritten: nPackets, - }, { - name: "Drop some with Postrouting chain", - setup: func(t *testing.T, stk *stack.Stack) { - // Install Postrouting DROP rule that matches only 1 - // of the 3 packets. - ipt := stk.IPTables() - filter := ipt.GetTable(stack.NATID, false /* ipv6 */) - // We'll match and DROP the last packet. - ruleIdx := filter.BuiltinChains[stack.Postrouting] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} - // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %s", err) - } - }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectOutputDropped: 0, - expectPostroutingDropped: 1, - expectWritten: nPackets, - }, - } - - // Parameterize the tests to run with both WritePacket and WritePackets. - writers := []struct { - name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) - }{ - { - name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - nWritten := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil { - return nWritten, err - } - nWritten++ - } - return nWritten, nil - }, - }, { - name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(pkts, stack.NetworkHeaderParams{}) - }, - }, - } - - for _, writer := range writers { - t.Run(writer.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) - rt := buildRoute(t, ep) - - var pkts stack.PacketBufferList - for i := 0; i < nPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), - Data: buffer.NewView(0).ToVectorisedView(), - }) - pkt.TransportHeader().Push(header.UDPMinimumSize) - pkts.PushBack(pkt) - } - - test.setup(t, rt.Stack()) - - nWritten, _ := writer.writePackets(rt, pkts) - - if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent) - } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped { - t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped) - } - if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped { - t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped) - } - if nWritten != test.expectWritten { - t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten) - } - }) - } - }) - } -} - -func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC(1, _) failed: %s", err) - } - const ( - src = tcpip.Address("\x10\x00\x00\x01") - dst = tcpip.Address("\x10\x00\x00\x02") - ) - protocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: src.WithPrefix(), - } - if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) - } - { - mask := tcpip.AddressMask(header.IPv4Broadcast) - subnet, err := tcpip.NewSubnet(dst, mask) - if err != nil { - t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err) - } - return rt -} - -// limitedMatcher is an iptables matcher that matches after a certain number of -// packets are checked against it. -type limitedMatcher struct { - limit int -} - -// Name implements Matcher.Name. -func (*limitedMatcher) Name() string { - return "limitedMatcher" -} - -// Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { - if lm.limit == 0 { - return true, false - } - lm.limit-- - return false, false -} - -func TestPacketQueuing(t *testing.T) { - const nicID = 1 - - var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - - host1IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), - PrefixLen: 24, - }, - } - host2IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), - PrefixLen: 8, - }, - } - ) - - tests := []struct { - name string - rxPkt func(*channel.Endpoint) - checkResp func(*testing.T, *channel.Endpoint) - }{ - { - name: "ICMP Error", - rxPkt: func(e *channel.Endpoint) { - hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize) - u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - u.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: 80, - Length: header.UDPMinimumSize, - }) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(nil, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, - TTL: ipv4.DefaultTTL, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, - DstAddr: host1IPv4Addr.AddressWithPrefix.Address, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.Read() - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != header.IPv4ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - }, - }, - - { - name: "Ping", - rxPkt: func(e *channel.Endpoint) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ipv4.DefaultTTL, - SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, - DstAddr: host1IPv4Addr.AddressWithPrefix.Address, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - }, - checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.Read() - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != header.IPv4ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) - } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) - } - checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4EchoReply), - checker.ICMPv4Code(header.ICMPv4UnusedCode))) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(1, defaultMTU, host1NICLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - Clock: clock, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: nicID, - }, - }) - - // Receive a packet to trigger link resolution before a response is sent. - test.rxPkt(e) - - // Wait for a ARP request since link address resolution should be - // performed. - { - clock.RunImmediatelyScheduledJobs() - p, ok := e.Read() - if !ok { - t.Fatalf("timed out waiting for packet") - } - if p.Proto != arp.ProtocolNumber { - t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) - } - if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) - } - rep := header.ARP(p.Pkt.NetworkHeader().View()) - if got := rep.Op(); got != header.ARPRequest { - t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr) - } - if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address) - } - if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address) - } - } - - // Send an ARP reply to complete link address resolution. - { - hdr := buffer.View(make([]byte, header.ARPSize)) - packet := header.ARP(hdr) - packet.SetIPv4OverEthernet() - packet.SetOp(header.ARPReply) - copy(packet.HardwareAddressSender(), host2NICLinkAddr) - copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address) - copy(packet.HardwareAddressTarget(), host1NICLinkAddr) - copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address) - e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.ToVectorisedView(), - })) - } - - // Expect the response now that the link address has resolved. - clock.RunImmediatelyScheduledJobs() - test.checkResp(t, e) - - // Since link resolution was already performed, it shouldn't be performed - // again. - test.rxPkt(e) - test.checkResp(t, e) - }) - } -} - -// TestCloseLocking test that lock ordering is followed when closing an -// endpoint. -func TestCloseLocking(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - iterations = 1000 - ) - - var ( - src = testutil.MustParse4("16.0.0.1") - dst = testutil.MustParse4("16.0.0.2") - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - // Perform NAT so that the endpoint tries to search for a sibling endpoint - // which ends up taking the protocol and endpoint lock (in that order). - table := stack.Table{ - Rules: []stack.Rule{ - {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 1, - stack.Forward: stack.HookUnset, - stack.Output: 2, - stack.Postrouting: 3, - }, - Underflows: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 1, - stack.Forward: stack.HookUnset, - stack.Output: 2, - stack.Postrouting: 3, - }, - } - if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil { - t.Fatalf("s.IPTables().ReplaceTable(...): %s", err) - } - - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - - protocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: src.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv4EmptySubnet, - NIC: nicID1, - }}) - - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53} - if err := ep.Connect(addr); err != nil { - t.Errorf("ep.Connect(%#v): %s", addr, err) - } - - var wg sync.WaitGroup - defer wg.Wait() - - // Writing packets should trigger NAT which requires the stack to search the - // protocol for network endpoints with the destination address. - // - // Creating and removing interfaces should modify the protocol and endpoint - // which requires taking the locks of each. - // - // We expect the protocol > endpoint lock ordering to be followed here. - wg.Add(2) - go func() { - defer wg.Done() - - data := []byte{1, 2, 3, 4} - - for i := 0; i < iterations; i++ { - var r bytes.Reader - r.Reset(data) - if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Errorf("ep.Write(_, _): %s", err) - return - } else if want := int64(len(data)); n != want { - t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) - return - } - } - }() - go func() { - defer wg.Done() - - for i := 0; i < iterations; i++ { - if err := s.CreateNIC(nicID2, stack.LinkEndpoint(channel.New(0, defaultMTU, ""))); err != nil { - t.Errorf("CreateNIC(%d, _): %s", nicID2, err) - return - } - if err := s.RemoveNIC(nicID2); err != nil { - t.Errorf("RemoveNIC(%d): %s", nicID2, err) - return - } - } - }() -} - -func TestIcmpRateLimit(t *testing.T) { - var ( - host1IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), - PrefixLen: 24, - }, - } - host2IPv4Addr = tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), - PrefixLen: 24, - }, - } - ) - const icmpBurst = 5 - e := channel.New(1, defaultMTU, tcpip.LinkAddress("")) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - Clock: faketime.NewManualClock(), - }) - s.SetICMPBurst(icmpBurst) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) - } - s.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: nicID, - }, - }) - tests := []struct { - name string - createPacket func() buffer.View - check func(*testing.T, *channel.Endpoint, int) - }{ - { - name: "echo", - createPacket: func() buffer.View { - totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLength) - icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - icmpH.SetIdent(1) - icmpH.SetSequence(1) - icmpH.SetType(header.ICMPv4Echo) - icmpH.SetCode(header.ICMPv4UnusedCode) - icmpH.SetChecksum(0) - icmpH.SetChecksum(^header.Checksum(icmpH, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLength), - Protocol: uint8(header.ICMPv4ProtocolNumber), - TTL: 1, - SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, - DstAddr: host1IPv4Addr.AddressWithPrefix.Address, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - return hdr.View() - }, - check: func(t *testing.T, e *channel.Endpoint, round int) { - p, ok := e.Read() - if !ok { - t.Fatalf("expected echo response, no packet read in endpoint in round %d", round) - } - if got, want := p.Proto, header.IPv4ProtocolNumber; got != want { - t.Errorf("got p.Proto = %d, want = %d", got, want) - } - checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4EchoReply), - )) - }, - }, - { - name: "dst unreachable", - createPacket: func() buffer.View { - totalLength := header.IPv4MinimumSize + header.UDPMinimumSize - hdr := buffer.NewPrependable(totalLength) - udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize)) - udpH.Encode(&header.UDPFields{ - SrcPort: 100, - DstPort: 101, - Length: header.UDPMinimumSize, - }) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLength), - Protocol: uint8(header.UDPProtocolNumber), - TTL: 1, - SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, - DstAddr: host1IPv4Addr.AddressWithPrefix.Address, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - return hdr.View() - }, - check: func(t *testing.T, e *channel.Endpoint, round int) { - p, ok := e.Read() - if round >= icmpBurst { - if ok { - t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round) - } - return - } - if !ok { - t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round) - } - checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - )) - }, - }, - } - for _, testCase := range tests { - t.Run(testCase.name, func(t *testing.T) { - for round := 0; round < icmpBurst+1; round++ { - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: testCase.createPacket().ToVectorisedView(), - })) - testCase.check(t, e, round) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go deleted file mode 100644 index d1f9e3cf5..000000000 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4 - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.NetworkInterface - nicID tcpip.NICID -} - -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEP := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEP { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - var nic testInterface - ep := proto.NewEndpoint(&nic, nil).(*endpoint) - // At this point, the Stack's stats and the NetworkEndpoint's stats are - // expected to be bound by a MultiCounterStat. - refStack := s.Stats() - refEP := ep.stats.localStats - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil { - t.Error(err) - } - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil { - t.Error(err) - } -} |