// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ipv4_test import ( "bytes" "context" "encoding/hex" "fmt" "io/ioutil" "math" "net" "testing" "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) const ( extraHeaderReserve = 50 defaultMTU = 65536 ) func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { ep = sniffer.New(ep) } if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: 1, }}) 
randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} var wq waiter.Queue t.Run("WithoutPrimaryAddress", func(t *testing.T) { ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { t.Fatal(err) } defer ep.Close() // Cannot connect using a broadcast address as the source. { err := ep.Connect(randomAddr) if _, ok := err.(*tcpip.ErrNoRoute); !ok { t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{}) } } // However, we can bind to a broadcast address to listen. if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { t.Errorf("Bind failed: %v", err) } }) t.Run("WithPrimaryAddress", func(t *testing.T) { ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { t.Fatal(err) } defer ep.Close() // Add a valid primary endpoint address, now we can connect. if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { t.Fatalf("AddAddress failed: %v", err) } if err := ep.Connect(randomAddr); err != nil { t.Errorf("Connect failed: %v", err) } }) } func TestForwarding(t *testing.T) { const ( nicID1 = 1 nicID2 = 2 randomSequence = 123 randomIdent = 42 randomTimeOffset = 0x10203040 ) ipv4Addr1 := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, } ipv4Addr2 := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), PrefixLen: 8, } remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) tests := []struct { name string TTL uint8 expectErrorICMP bool options header.IPv4Options forwardedOptions header.IPv4Options icmpType header.ICMPv4Type icmpCode header.ICMPv4Code }{ { name: "TTL of zero", TTL: 0, expectErrorICMP: true, icmpType: header.ICMPv4TimeExceeded, icmpCode: header.ICMPv4TTLExceeded, }, { name: "TTL of one", TTL: 1, expectErrorICMP: false, }, { name: "TTL of two", TTL: 2, 
expectErrorICMP: false, }, { name: "Max TTL", TTL: math.MaxUint8, expectErrorICMP: false, }, { name: "four EOL options", TTL: 2, expectErrorICMP: false, options: header.IPv4Options{0, 0, 0, 0}, forwardedOptions: header.IPv4Options{0, 0, 0, 0}, }, { name: "TS type 1 full", TTL: 2, options: header.IPv4Options{ 68, 12, 13, 0xF1, 192, 168, 1, 12, 1, 2, 3, 4, }, expectErrorICMP: true, icmpType: header.ICMPv4ParamProblem, icmpCode: header.ICMPv4UnusedCode, }, { name: "TS type 0", TTL: 2, options: header.IPv4Options{ 68, 24, 21, 0x00, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 0, 0, }, forwardedOptions: header.IPv4Options{ 68, 24, 25, 0x00, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock }, }, { name: "end of options list", TTL: 2, options: header.IPv4Options{ 68, 12, 13, 0x11, 192, 168, 1, 12, 1, 2, 3, 4, 0, 10, 3, 99, // EOL followed by junk 1, 2, 3, 4, }, forwardedOptions: header.IPv4Options{ 68, 12, 13, 0x21, 192, 168, 1, 12, 1, 2, 3, 4, 0, // End of Options hides following bytes. 0, 0, 0, // 7 bytes unknown option removed. 0, 0, 0, 0, }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, Clock: clock, }) // Advance the clock by some unimportant amount to make // it give a more recognisable signature than 00,00,00,00. clock.Advance(time.Millisecond * randomTimeOffset) // We expect at most a single packet in response to our ICMP Echo Request. 
e1 := channel.New(1, ipv4.MaxTotalSize, "") if err := s.CreateNIC(nicID1, e1); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) } e2 := channel.New(1, ipv4.MaxTotalSize, "") if err := s.CreateNIC(nicID2, e2); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) } ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) } s.SetRouteTable([]tcpip.Route{ { Destination: ipv4Addr1.Subnet(), NIC: nicID1, }, { Destination: ipv4Addr2.Subnet(), NIC: nicID2, }, }) if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) } ipHeaderLength := header.IPv4MinimumSize + len(test.options) if ipHeaderLength > header.IPv4MaximumHeaderSize { t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) } totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) icmp.SetIdent(randomIdent) icmp.SetSequence(randomSequence) icmp.SetType(header.ICMPv4Echo) icmp.SetCode(header.ICMPv4UnusedCode) icmp.SetChecksum(0) icmp.SetChecksum(^header.Checksum(icmp, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, Protocol: uint8(header.ICMPv4ProtocolNumber), TTL: test.TTL, SrcAddr: remoteIPv4Addr1, DstAddr: remoteIPv4Addr2, }) if len(test.options) != 0 { ip.SetHeaderLength(uint8(ipHeaderLength)) // Copy options manually. 
We do not use Encode for options so we can // verify malformed options with handcrafted payloads. if want, got := copy(ip.Options(), test.options), len(test.options); want != got { t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) } } ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) if test.expectErrorICMP { reply, ok := e1.Read() if !ok { t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) } checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), checker.SrcAddr(ipv4Addr1.Address), checker.DstAddr(remoteIPv4Addr1), checker.TTL(ipv4.DefaultTTL), checker.ICMPv4( checker.ICMPv4Checksum(), checker.ICMPv4Type(test.icmpType), checker.ICMPv4Code(test.icmpCode), checker.ICMPv4Payload([]byte(hdr.View())), ), ) if n := e2.Drain(); n != 0 { t.Fatalf("got e2.Drain() = %d, want = 0", n) } } else { reply, ok := e2.Read() if !ok { t.Fatal("expected ICMP Echo packet through outgoing NIC") } checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), checker.SrcAddr(remoteIPv4Addr1), checker.DstAddr(remoteIPv4Addr2), checker.TTL(test.TTL-1), checker.IPv4Options(test.forwardedOptions), checker.ICMPv4( checker.ICMPv4Checksum(), checker.ICMPv4Type(header.ICMPv4Echo), checker.ICMPv4Code(header.ICMPv4UnusedCode), checker.ICMPv4Payload(nil), ), ) if n := e1.Drain(); n != 0 { t.Fatalf("got e1.Drain() = %d, want = 0", n) } } }) } } // TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and // checks the response. func TestIPv4Sanity(t *testing.T) { const ( ttl = 255 nicID = 1 randomSequence = 123 randomIdent = 42 // In some cases Linux sets the error pointer to the start of the option // (offset 0) instead of the actual wrong value, which is the length byte // (offset 1). For compatibility we must do the same. 
Use this constant // to indicate where this happens. pointerOffsetForInvalidLength = 0 randomTimeOffset = 0x10203040 ) var ( ipv4Addr = tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), PrefixLen: 24, } remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) ) tests := []struct { name string headerLength uint8 // value of 0 means "use correct size" badHeaderChecksum bool maxTotalLength uint16 transportProtocol uint8 TTL uint8 options header.IPv4Options replyOptions header.IPv4Options // reply should look like this shouldFail bool expectErrorICMP bool ICMPType header.ICMPv4Type ICMPCode header.ICMPv4Code paramProblemPointer uint8 }{ { name: "valid no options", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, }, { name: "bad header checksum", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, badHeaderChecksum: true, shouldFail: true, }, // The TTL tests check that we are not rejecting an incoming packet // with a zero or one TTL, which has been a point of confusion in the // past as RFC 791 says: "If this field contains the value zero, then the // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies // for the case of the destination host, stating as follows. // // A host MUST NOT send a datagram with a Time-to-Live (TTL) // value of zero. // // A host MUST NOT discard a datagram just because it was // received with TTL less than 2. 
{ name: "zero TTL", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 0, }, { name: "one TTL", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 1, }, { name: "End options", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{0, 0, 0, 0}, replyOptions: header.IPv4Options{0, 0, 0, 0}, }, { name: "NOP options", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{1, 1, 1, 1}, replyOptions: header.IPv4Options{1, 1, 1, 1}, }, { name: "NOP and End options", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{1, 1, 0, 0}, replyOptions: header.IPv4Options{1, 1, 0, 0}, }, { name: "bad header length", headerLength: header.IPv4MinimumSize - 1, maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, }, { name: "bad total length (0)", maxTotalLength: 0, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, }, { name: "bad total length (ip - 1)", maxTotalLength: uint16(header.IPv4MinimumSize - 1), transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, }, { name: "bad total length (ip + icmp - 1)", maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, }, { name: "bad protocol", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: 99, TTL: ttl, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4DstUnreachable, ICMPCode: header.ICMPv4ProtoUnreachable, }, { name: "timestamp option overflow", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 12, 13, 
0x11, 192, 168, 1, 12, 1, 2, 3, 4, }, replyOptions: header.IPv4Options{ 68, 12, 13, 0x21, 192, 168, 1, 12, 1, 2, 3, 4, }, }, { name: "timestamp option overflow full", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 12, 13, 0xF1, // ^ Counter full (15/0xF) 192, 168, 1, 12, 1, 2, 3, 4, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + 3, replyOptions: header.IPv4Options{}, }, { name: "unknown option", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{10, 4, 9, 0}, // ^^ // The unknown option should be stripped out of the reply. replyOptions: header.IPv4Options{}, }, { name: "bad option - no length", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 1, 1, 1, 68, // ^-start of timestamp.. but no length.. 
}, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + 3, }, { name: "bad option - length 0", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 0, 9, 0, // ^ 1, 2, 3, 4, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { name: "bad option - length 1", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 1, 9, 0, // ^ 1, 2, 3, 4, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { name: "bad option - length big", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 9, 9, 0, // ^ // There are only 8 bytes allocated to options so 9 bytes of timestamp // space is not possible. (Second byte) 1, 2, 3, 4, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { // This tests for some linux compatible behaviour. // The ICMP pointer returned is 22 for Linux but the // error is actually in spot 21. name: "bad option - length bad", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, // Timestamps are in multiples of 4 or 8 but never 7. // The option space should be padded out. options: header.IPv4Options{ 68, 7, 5, 0, // ^ ^ Linux points here which is wrong. 
// | Not a multiple of 4 1, 2, 3, 0, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, { name: "multiple type 0 with room", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 24, 21, 0x00, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 0, 0, }, replyOptions: header.IPv4Options{ 68, 24, 25, 0x00, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock }, }, { // The timestamp area is full so add to the overflow count. name: "multiple type 1 timestamps", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 20, 21, 0x11, // ^ 192, 168, 1, 12, 1, 2, 3, 4, 192, 168, 1, 13, 5, 6, 7, 8, }, // Overflow count is the top nibble of the 4th byte. replyOptions: header.IPv4Options{ 68, 20, 21, 0x21, // ^ 192, 168, 1, 12, 1, 2, 3, 4, 192, 168, 1, 13, 5, 6, 7, 8, }, }, { name: "multiple type 1 timestamps with room", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 28, 21, 0x01, 192, 168, 1, 12, 1, 2, 3, 4, 192, 168, 1, 13, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, }, replyOptions: header.IPv4Options{ 68, 28, 29, 0x01, 192, 168, 1, 12, 1, 2, 3, 4, 192, 168, 1, 13, 5, 6, 7, 8, 192, 168, 1, 58, // New IP Address. 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock }, }, { // Timestamp pointer uses one based counting so 0 is invalid. name: "timestamp pointer invalid", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 8, 0, 0x00, // ^ 0 instead of 5 or more. 
0, 0, 0, 0, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + 2, }, { // Timestamp pointer cannot be less than 5. It must point past the header // which is 4 bytes. (1 based counting) name: "timestamp pointer too small by 1", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 8, header.IPv4OptionTimestampHdrLength, 0x00, // ^ header is 4 bytes, so 4 should fail. 0, 0, 0, 0, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, { name: "valid timestamp pointer", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00, // ^ header is 4 bytes, so 5 should succeed. 0, 0, 0, 0, }, replyOptions: header.IPv4Options{ 68, 8, 9, 0x00, 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock }, }, { // Needs 8 bytes for a type 1 timestamp but there are only 4 free. name: "bad timer element alignment", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 20, 17, 0x01, // ^^ ^^ 20 byte area, next free spot at 17. 192, 168, 1, 12, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, // End of option list with illegal option after it, which should be ignored. 
{ name: "end of options list", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 12, 13, 0x11, 192, 168, 1, 12, 1, 2, 3, 4, 0, 10, 3, 99, // EOL followed by junk }, replyOptions: header.IPv4Options{ 68, 12, 13, 0x21, 192, 168, 1, 12, 1, 2, 3, 4, 0, // End of Options hides following bytes. 0, 0, 0, // 3 bytes unknown option removed. }, }, { // Timestamp with a size much too small. name: "timestamp truncated", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 68, 1, 0, 0, // ^ Smallest possible is 8. Linux points at the 68. }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { name: "single record route with room", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 7, 7, 4, // 3 byte header 0, 0, 0, 0, 0, }, replyOptions: header.IPv4Options{ 7, 7, 8, // 3 byte header 192, 168, 1, 58, // New IP Address. 0, // padding to multiple of 4 bytes. }, }, { name: "multiple record route with room", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 7, 23, 20, // 3 byte header 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 0, 0, 0, }, replyOptions: header.IPv4Options{ 7, 23, 24, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 192, 168, 1, 58, // New IP Address. 0, // padding to multiple of 4 bytes. 
}, }, { name: "single record route with no room", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 7, 7, 8, // 3 byte header 1, 2, 3, 4, 0, }, replyOptions: header.IPv4Options{ 7, 7, 8, // 3 byte header 1, 2, 3, 4, 0, // padding to multiple of 4 bytes. }, }, { // Unlike timestamp, this should just succeed. name: "multiple record route with no room", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 7, 23, 24, // 3 byte header 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, }, replyOptions: header.IPv4Options{ 7, 23, 24, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, // padding to multiple of 4 bytes. }, }, { // Pointer uses one based counting so 0 is invalid. name: "record route pointer zero", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 7, 8, 0, // 3 byte header 0, 0, 0, 0, 0, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { // Pointer must be 4 or more as it must point past the 3 byte header // using 1 based counting. 3 should fail. name: "record route pointer too small by 1", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header 0, 0, 0, 0, 0, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { // Pointer must be 4 or more as it must point past the 3 byte header // using 1 based counting. Check 4 passes. 
(Duplicates "single // record route with room") name: "valid record route pointer", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header 0, 0, 0, 0, 0, }, replyOptions: header.IPv4Options{ 7, 7, 8, // 3 byte header 192, 168, 1, 58, // New IP Address. 0, // padding to multiple of 4 bytes. }, }, { // Confirm Linux bug for bug compatibility. // Linux returns slot 22 but the error is in slot 21. name: "multiple record route with not enough room", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 7, 8, 8, // 3 byte header // ^ ^ Linux points here. We must too. // | Not enough room. 1 byte free, need 4. 1, 2, 3, 4, 0, }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { name: "duplicate record route", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: header.IPv4Options{ 7, 7, 8, // 3 byte header 1, 2, 3, 4, 7, 7, 8, // 3 byte header 1, 2, 3, 4, 0, 0, // pad }, shouldFail: true, expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + 7, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, Clock: clock, }) // Advance the clock by some unimportant amount to make // it give a more recognisable signature than 00,00,00,00. clock.Advance(time.Millisecond * randomTimeOffset) // We expect at most a single packet in response to our ICMP Echo Request. 
e := channel.New(1, ipv4.MaxTotalSize, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) } // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. s.SetRouteTable([]tcpip.Route{ { Destination: header.IPv4EmptySubnet, NIC: nicID, }, }) if len(test.options)%4 != 0 { t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options)) } ipHeaderLength := header.IPv4MinimumSize + len(test.options) if ipHeaderLength > header.IPv4MaximumHeaderSize { t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) } totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) // Specify ident/seq to make sure we get the same in the response. icmp.SetIdent(randomIdent) icmp.SetSequence(randomSequence) icmp.SetType(header.ICMPv4Echo) icmp.SetCode(header.ICMPv4UnusedCode) icmp.SetChecksum(0) icmp.SetChecksum(^header.Checksum(icmp, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength } ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, Protocol: test.transportProtocol, TTL: test.TTL, SrcAddr: remoteIPv4Addr, DstAddr: ipv4Addr.Address, }) if test.headerLength != 0 { ip.SetHeaderLength(test.headerLength) } else { // Set the calculated header length, since we may manually add options. ip.SetHeaderLength(uint8(ipHeaderLength)) } if len(test.options) != 0 { // Copy options manually. 
We do not use Encode for options so we can // verify malformed options with handcrafted payloads. if want, got := copy(ip.Options(), test.options), len(test.options); want != got { t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) } } ip.SetChecksum(0) ipHeaderChecksum := ip.CalculateChecksum() if test.badHeaderChecksum { ipHeaderChecksum += 42 } ip.SetChecksum(^ipHeaderChecksum) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) reply, ok := e.Read() if !ok { if test.shouldFail { if test.expectErrorICMP { t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) } return // Expected silent failure. } t.Fatal("expected ICMP echo reply missing") } // We didn't expect a packet. Register our surprise but carry on to // provide more information about what we got. if test.shouldFail && !test.expectErrorICMP { t.Error("unexpected packet response") } // Check the route that brought the packet to us. if reply.Route.LocalAddress != ipv4Addr.Address { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) } if reply.Route.RemoteAddress != remoteIPv4Addr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) } // Make sure it's all in one buffer for checker. replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) // At this stage we only know it's probably an IP+ICMP header so verify // that much. checker.IPv4(t, replyIPHeader, checker.SrcAddr(ipv4Addr.Address), checker.DstAddr(remoteIPv4Addr), checker.ICMPv4( checker.ICMPv4Checksum(), ), ) // Don't proceed any further if the checker found problems. if t.Failed() { t.FailNow() } // OK it's ICMP. We can safely look at the type now. 
replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) switch replyICMPHeader.Type() { case header.ICMPv4ParamProblem: if !test.shouldFail { t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) } if !test.expectErrorICMP { t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) } checker.IPv4(t, replyIPHeader, checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), checker.IPv4HeaderLength(header.IPv4MinimumSize), checker.ICMPv4( checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), checker.ICMPv4Pointer(test.paramProblemPointer), checker.ICMPv4Payload([]byte(hdr.View())), ), ) return case header.ICMPv4DstUnreachable: if !test.shouldFail { t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", header.ICMPv4DstUnreachable, replyICMPHeader.Code()) } if !test.expectErrorICMP { t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", header.ICMPv4DstUnreachable, replyICMPHeader.Code()) } checker.IPv4(t, replyIPHeader, checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), checker.IPv4HeaderLength(header.IPv4MinimumSize), checker.ICMPv4( checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), checker.ICMPv4Payload([]byte(hdr.View())), ), ) return case header.ICMPv4EchoReply: if test.shouldFail { if !test.expectErrorICMP { t.Error("got Echo Reply packet, want no response") } else { t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) } } // If the IP options change size then the packet will change size, so // some IP header fields will need to be adjusted for the checks. 
sizeChange := len(test.replyOptions) - len(test.options) checker.IPv4(t, replyIPHeader, checker.IPv4HeaderLength(ipHeaderLength+sizeChange), checker.IPv4Options(test.replyOptions), checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), checker.ICMPv4( checker.ICMPv4Checksum(), checker.ICMPv4Code(header.ICMPv4UnusedCode), checker.ICMPv4Seq(randomSequence), checker.ICMPv4Ident(randomIdent), ), ) default: t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) } }) } } // comparePayloads compared the contents of all the packets against the contents // of the source packet. func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { // Make a complete array of the sourcePacket packet. source := header.IPv4(packets[0].NetworkHeader().View()) vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) sourceCopy.SetChecksum(0) sourceCopy.SetFlagsFragmentOffset(0, 0) sourceCopy.SetTotalLength(0) // Build up an array of the bytes sent. var reassembledPayload buffer.VectorisedView for i, packet := range packets { // Confirm that the packet is valid. 
allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) fragmentIPHeader := header.IPv4(allBytes.ToView()) if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) } if got := len(fragmentIPHeader); got > int(mtu) { return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) } if got := fragmentIPHeader.TransportProtocol(); got != proto { return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) } if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve { return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) } if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) } if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) } if wantFragments[i].more { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset) } else { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) } reassembledPayload.AppendView(packet.TransportHeader().View()) reassembledPayload.AppendView(packet.Data().AsRange().ToOwnedView()) // Clear out the checksum and length from the ip because we can't compare // it. 
sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } } expected := buffer.View(source[source.HeaderLength():]) if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } return nil } type fragmentInfo struct { offset uint16 more bool payloadSize uint16 } var fragmentationTests = []struct { description string mtu uint32 gso *stack.GSO transportHeaderLength int payloadSize int wantFragments []fragmentInfo }{ { description: "No fragmentation", mtu: 1280, gso: nil, transportHeaderLength: 0, payloadSize: 1000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1000, more: false}, }, }, { description: "Fragmented", mtu: 1280, gso: nil, transportHeaderLength: 0, payloadSize: 2000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1256, more: true}, {offset: 1256, payloadSize: 744, more: false}, }, }, { description: "Fragmented with the minimum mtu", mtu: header.IPv4MinimumMTU, gso: nil, transportHeaderLength: 0, payloadSize: 100, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 48, more: true}, {offset: 48, payloadSize: 48, more: true}, {offset: 96, payloadSize: 4, more: false}, }, }, { description: "Fragmented with mtu not a multiple of 8", mtu: header.IPv4MinimumMTU + 1, gso: nil, transportHeaderLength: 0, payloadSize: 100, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 48, more: true}, {offset: 48, payloadSize: 48, more: true}, {offset: 96, payloadSize: 4, more: false}, }, }, { description: "No fragmentation with big header", mtu: 2000, gso: nil, transportHeaderLength: 100, payloadSize: 1000, wantFragments: []fragmentInfo{ {offset: 0, 
payloadSize: 1100, more: false}, }, }, { description: "Fragmented with gso none", mtu: 1280, gso: &stack.GSO{Type: stack.GSONone}, transportHeaderLength: 0, payloadSize: 1400, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1256, more: true}, {offset: 1256, payloadSize: 144, more: false}, }, }, { description: "Fragmented with big header", mtu: 1280, gso: nil, transportHeaderLength: 100, payloadSize: 1200, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1256, more: true}, {offset: 1256, payloadSize: 44, more: false}, }, }, { description: "Fragmented with MTU smaller than header", mtu: 300, gso: nil, transportHeaderLength: 1000, payloadSize: 500, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 280, more: true}, {offset: 280, payloadSize: 280, more: true}, {offset: 560, payloadSize: 280, more: true}, {offset: 840, payloadSize: 280, more: true}, {offset: 1120, payloadSize: 280, more: true}, {offset: 1400, payloadSize: 100, more: false}, }, }, } func TestFragmentationWritePacket(t *testing.T) { const ttl = 42 for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) if err != nil { t.Fatalf("r.WritePacket(_, _, _) = %s", err) } if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) } if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) } if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got 
r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { t.Error(err) } }) } } func TestFragmentationWritePackets(t *testing.T) { const ttl = 42 writePacketsTests := []struct { description string insertBefore int insertAfter int }{ { description: "Single packet", insertBefore: 0, insertAfter: 0, }, { description: "With packet before", insertBefore: 1, insertAfter: 0, }, { description: "With packet after", insertBefore: 0, insertAfter: 1, }, { description: "With packet before and after", insertBefore: 1, insertAfter: 1, }, } tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) for _, test := range writePacketsTests { t.Run(test.description, func(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { var pkts stack.PacketBufferList for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }) if err != nil { t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err) } if n != wantTotalPackets { t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets) } if got := len(ep.WrittenPackets); got != wantTotalPackets { t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) } if got := 
int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) } if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } if wantTotalPackets == 0 { return } fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { t.Error(err) } }) } }) } } // TestFragmentationErrors checks that errors are returned from WritePacket // correctly. func TestFragmentationErrors(t *testing.T) { const ttl = 42 tests := []struct { description string mtu uint32 transportHeaderLength int payloadSize int allowPackets int outgoingErrors int mockError tcpip.Error wantError tcpip.Error }{ { description: "No frag", mtu: 2000, payloadSize: 1000, transportHeaderLength: 0, allowPackets: 0, outgoingErrors: 1, mockError: &tcpip.ErrAborted{}, wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag", mtu: 500, payloadSize: 1000, transportHeaderLength: 0, allowPackets: 0, outgoingErrors: 3, mockError: &tcpip.ErrAborted{}, wantError: &tcpip.ErrAborted{}, }, { description: "Error on second frag", mtu: 500, payloadSize: 1000, transportHeaderLength: 0, allowPackets: 1, outgoingErrors: 2, mockError: &tcpip.ErrAborted{}, wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag MTU smaller than header", mtu: 500, transportHeaderLength: 1000, payloadSize: 500, allowPackets: 0, outgoingErrors: 4, mockError: &tcpip.ErrAborted{}, wantError: &tcpip.ErrAborted{}, }, { description: "Error when MTU is smaller than IPv4 minimum MTU", mtu: header.IPv4MinimumMTU - 1, transportHeaderLength: 0, payloadSize: 500, allowPackets: 0, outgoingErrors: 1, mockError: nil, wantError: &tcpip.ErrInvalidEndpointState{}, }, } for _, ft := range tests { t.Run(ft.description, func(t 
*testing.T) { pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) if diff := cmp.Diff(ft.wantError, err); diff != "" { t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff) } if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) } if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) } }) } } func TestInvalidFragments(t *testing.T) { const ( nicID = 1 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") addr1 = "\x0a\x00\x00\x01" addr2 = "\x0a\x00\x00\x02" tos = 0 ident = 1 ttl = 48 protocol = 6 ) payloadGen := func(payloadLen int) []byte { payload := make([]byte, payloadLen) for i := 0; i < len(payload); i++ { payload[i] = 0x30 } return payload } type fragmentData struct { ipv4fields header.IPv4Fields // 0 means insert the correct IHL. Non 0 means override the correct IHL. overrideIHL int // For 0 use 1 as it is an int and will be divided by 4. payload []byte autoChecksum bool // If true, the Checksum field will be overwritten. 
} tests := []struct { name string fragments []fragmentData wantMalformedIPPackets uint64 wantMalformedFragments uint64 }{ { name: "IHL and TotalLength zero, FragmentOffset non-zero", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: 0, ID: ident, Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments, FragmentOffset: 59776, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, overrideIHL: 1, // See note above. payload: payloadGen(12), autoChecksum: true, }, }, wantMalformedIPPackets: 1, wantMalformedFragments: 0, }, { name: "IHL and TotalLength zero, FragmentOffset zero", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: 0, ID: ident, Flags: header.IPv4FlagMoreFragments, FragmentOffset: 0, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, overrideIHL: 1, // See note above. payload: payloadGen(12), autoChecksum: true, }, }, wantMalformedIPPackets: 1, wantMalformedFragments: 0, }, { // Payload 17 octets and Fragment offset 65520 // Leading to the fragment end to be past 65536. name: "fragment ends past 65536", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 17, ID: ident, Flags: 0, FragmentOffset: 65520, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: payloadGen(17), autoChecksum: true, }, }, wantMalformedIPPackets: 1, wantMalformedFragments: 1, }, { // Payload 16 octets and fragment offset 65520 // Leading to the fragment end to be exactly 65536. 
name: "fragment ends exactly at 65536", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 16, ID: ident, Flags: 0, FragmentOffset: 65520, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: payloadGen(16), autoChecksum: true, }, }, wantMalformedIPPackets: 0, wantMalformedFragments: 0, }, { name: "IHL less than IPv4 minimum size", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 28, ID: ident, Flags: 0, FragmentOffset: 1944, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: payloadGen(28), overrideIHL: header.IPv4MinimumSize - 12, autoChecksum: true, }, { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize - 12, ID: ident, Flags: header.IPv4FlagMoreFragments, FragmentOffset: 0, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: payloadGen(28), overrideIHL: header.IPv4MinimumSize - 12, autoChecksum: true, }, }, wantMalformedIPPackets: 2, wantMalformedFragments: 0, }, { name: "fragment with short TotalLength and extra payload", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 28, ID: ident, Flags: 0, FragmentOffset: 28816, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: payloadGen(28), overrideIHL: header.IPv4MinimumSize + 4, autoChecksum: true, }, { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 4, ID: ident, Flags: header.IPv4FlagMoreFragments, FragmentOffset: 0, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: payloadGen(28), overrideIHL: header.IPv4MinimumSize + 4, autoChecksum: true, }, }, wantMalformedIPPackets: 1, wantMalformedFragments: 1, }, { name: "multiple fragments with More Fragments flag set to false", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize 
+ 8, ID: ident, Flags: 0, FragmentOffset: 128, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: payloadGen(8), autoChecksum: true, }, { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 8, ID: ident, Flags: 0, FragmentOffset: 8, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: payloadGen(8), autoChecksum: true, }, { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 8, ID: ident, Flags: header.IPv4FlagMoreFragments, FragmentOffset: 0, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: payloadGen(8), autoChecksum: true, }, }, wantMalformedIPPackets: 1, wantMalformedFragments: 1, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, }, }) e := channel.New(0, 1500, linkAddr) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) } for _, f := range test.fragments { pktSize := header.IPv4MinimumSize + len(f.payload) hdr := buffer.NewPrependable(pktSize) ip := header.IPv4(hdr.Prepend(pktSize)) ip.Encode(&f.ipv4fields) if want, got := len(f.payload), copy(ip[header.IPv4MinimumSize:], f.payload); want != got { t.Fatalf("copied %d bytes, expected %d bytes.", got, want) } // Encode sets this up correctly. If we want a different value for // testing then we need to overwrite the good value. if f.overrideIHL != 0 { ip.SetHeaderLength(uint8(f.overrideIHL)) // If we are asked to add options (type not specified) then pad // with 0 (EOL). RFC 791 page 23 says "The padding is zero". 
for i := header.IPv4MinimumSize; i < f.overrideIHL; i++ { ip[i] = byte(header.IPv4OptionListEndType) } } if f.autoChecksum { ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) } vv := hdr.View().ToVectorisedView() e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, })) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) } if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) } }) } } func TestFragmentReassemblyTimeout(t *testing.T) { const ( nicID = 1 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") addr1 = "\x0a\x00\x00\x01" addr2 = "\x0a\x00\x00\x02" tos = 0 ident = 1 ttl = 48 protocol = 99 data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" ) type fragmentData struct { ipv4fields header.IPv4Fields payload []byte } tests := []struct { name string fragments []fragmentData expectICMP bool }{ { name: "first fragment only", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 16, ID: ident, Flags: header.IPv4FlagMoreFragments, FragmentOffset: 0, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: []byte(data)[:16], }, }, expectICMP: true, }, { name: "two first fragments", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 16, ID: ident, Flags: header.IPv4FlagMoreFragments, FragmentOffset: 0, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: []byte(data)[:16], }, { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 16, ID: ident, Flags: header.IPv4FlagMoreFragments, FragmentOffset: 0, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, 
payload: []byte(data)[:16], }, }, expectICMP: true, }, { name: "second fragment only", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), ID: ident, Flags: 0, FragmentOffset: 8, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: []byte(data)[16:], }, }, expectICMP: false, }, { name: "two fragments with a gap", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 8, ID: ident, Flags: header.IPv4FlagMoreFragments, FragmentOffset: 0, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: []byte(data)[:8], }, { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), ID: ident, Flags: 0, FragmentOffset: 16, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: []byte(data)[16:], }, }, expectICMP: true, }, { name: "two fragments with a gap in reverse order", fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), ID: ident, Flags: 0, FragmentOffset: 16, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: []byte(data)[16:], }, { ipv4fields: header.IPv4Fields{ TOS: tos, TotalLength: header.IPv4MinimumSize + 8, ID: ident, Flags: header.IPv4FlagMoreFragments, FragmentOffset: 0, TTL: ttl, Protocol: protocol, SrcAddr: addr1, DstAddr: addr2, }, payload: []byte(data)[:8], }, }, expectICMP: true, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, }, Clock: clock, }) e := channel.New(1, 1500, linkAddr) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { t.Fatalf("AddAddress(%d, %d, 
%s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: nicID, }}) var firstFragmentSent buffer.View for _, f := range test.fragments { pktSize := header.IPv4MinimumSize hdr := buffer.NewPrependable(pktSize) ip := header.IPv4(hdr.Prepend(pktSize)) ip.Encode(&f.ipv4fields) ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) vv := hdr.View().ToVectorisedView() vv.AppendView(f.payload) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, }) if firstFragmentSent == nil && ip.FragmentOffset() == 0 { firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) } e.InjectInbound(header.IPv4ProtocolNumber, pkt) } clock.Advance(ipv4.ReassembleTimeout) reply, ok := e.Read() if !test.expectICMP { if ok { t.Fatalf("unexpected ICMP error message received: %#v", reply) } return } if !ok { t.Fatal("expected ICMP error message missing") } if firstFragmentSent == nil { t.Fatalf("unexpected ICMP error message received: %#v", reply) } checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(addr2), checker.DstAddr(addr1), checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), checker.IPv4HeaderLength(header.IPv4MinimumSize), checker.ICMPv4( checker.ICMPv4Type(header.ICMPv4TimeExceeded), checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), checker.ICMPv4Checksum(), checker.ICMPv4Payload([]byte(firstFragmentSent)), ), ) }) } } // TestReceiveFragments feeds fragments in through the incoming packet path to // test reassembly func TestReceiveFragments(t *testing.T) { const ( nicID = 1 addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 ) // Build and return a UDP header containing payload. 
udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View { payload := buffer.NewView(payloadLen) for i := 0; i < len(payload); i++ { payload[i] = uint8(i) * multiplier } udpLength := header.UDPMinimumSize + len(payload) hdr := buffer.NewPrependable(udpLength) u := header.UDP(hdr.Prepend(udpLength)) u.Encode(&header.UDPFields{ SrcPort: 5555, DstPort: 80, Length: uint16(udpLength), }) copy(u.Payload(), payload) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) sum = header.Checksum(payload, sum) u.SetChecksum(^u.CalculateChecksum(sum)) return hdr.View() } // UDP header plus a payload of 0..256 ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2) udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:] ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2) udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:] // UDP header plus a payload of 0..256 in increments of 2. ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2) udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:] // UDP header plus a payload of 0..256 in increments of 3. // Used to test cases where the fragment blocks are not a multiple of // the fragment block size of 8 (RFC 791 section 3.1 page 14). ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] // Used to test the max reassembled IPv4 payload length. 
ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] type fragmentData struct { srcAddr tcpip.Address dstAddr tcpip.Address id uint16 flags uint8 fragmentOffset uint16 payload buffer.View } tests := []struct { name string fragments []fragmentData expectedPayloads [][]byte }{ { name: "No fragmentation", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 0, payload: ipv4Payload1Addr1ToAddr2, }, }, expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, }, { name: "No fragmentation with size not a multiple of fragment block size", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 0, payload: ipv4Payload3Addr1ToAddr2, }, }, expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, }, { name: "More fragments without payload", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload1Addr1ToAddr2, }, }, expectedPayloads: nil, }, { name: "Non-zero fragment offset without payload", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 8, payload: ipv4Payload1Addr1ToAddr2, }, }, expectedPayloads: nil, }, { name: "Two fragments", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload1Addr1ToAddr2[:64], }, { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 64, payload: ipv4Payload1Addr1ToAddr2[64:], }, }, expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, }, { name: "Two fragments out of order", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 64, payload: ipv4Payload1Addr1ToAddr2[64:], }, { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: 
ipv4Payload1Addr1ToAddr2[:64], }, }, expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, }, { name: "Two fragments with last fragment size not a multiple of fragment block size", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload3Addr1ToAddr2[:64], }, { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 64, payload: ipv4Payload3Addr1ToAddr2[64:], }, }, expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, }, { name: "Two fragments with first fragment size not a multiple of fragment block size", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload3Addr1ToAddr2[:63], }, { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 63, payload: ipv4Payload3Addr1ToAddr2[63:], }, }, expectedPayloads: nil, }, { name: "Second fragment has MoreFlags set", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload1Addr1ToAddr2[:64], }, { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 64, payload: ipv4Payload1Addr1ToAddr2[64:], }, }, expectedPayloads: nil, }, { name: "Two fragments with different IDs", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload1Addr1ToAddr2[:64], }, { srcAddr: addr1, dstAddr: addr2, id: 2, flags: 0, fragmentOffset: 64, payload: ipv4Payload1Addr1ToAddr2[64:], }, }, expectedPayloads: nil, }, { name: "Two interleaved fragmented packets", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload1Addr1ToAddr2[:64], }, { srcAddr: addr1, dstAddr: addr2, id: 2, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: 
ipv4Payload2Addr1ToAddr2[:64], }, { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 64, payload: ipv4Payload1Addr1ToAddr2[64:], }, { srcAddr: addr1, dstAddr: addr2, id: 2, flags: 0, fragmentOffset: 64, payload: ipv4Payload2Addr1ToAddr2[64:], }, }, expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, }, { name: "Two interleaved fragmented packets from different sources but with same ID", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload1Addr1ToAddr2[:64], }, { srcAddr: addr3, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload1Addr3ToAddr2[:32], }, { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 64, payload: ipv4Payload1Addr1ToAddr2[64:], }, { srcAddr: addr3, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 32, payload: ipv4Payload1Addr3ToAddr2[32:], }, }, expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, }, { name: "Fragment without followup", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload1Addr1ToAddr2[:64], }, }, expectedPayloads: nil, }, { name: "Two fragments reassembled into a maximum UDP packet", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload4Addr1ToAddr2[:65512], }, { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 65512, payload: ipv4Payload4Addr1ToAddr2[65512:], }, }, expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, }, { name: "Two fragments with MF flag reassembled into a maximum UDP packet", fragments: []fragmentData{ { srcAddr: addr1, dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, payload: ipv4Payload4Addr1ToAddr2[:65512], }, { srcAddr: addr1, dstAddr: addr2, id: 1, flags: 
header.IPv4FlagMoreFragments, fragmentOffset: 65512, payload: ipv4Payload4Addr1ToAddr2[65512:], }, }, expectedPayloads: nil, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Setup a stack and endpoint. s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, RawFactory: raw.EndpointFactory{}, }) e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) } wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) defer close(ch) ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) } defer ep.Close() bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} if err := ep.Bind(bindAddr); err != nil { t.Fatalf("Bind(%+v): %s", bindAddr, err) } // Bring up a raw endpoint so we can examine network headers. epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */) if err != nil { t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) } defer epRaw.Close() // Prepare and send the fragments. for _, frag := range test.fragments { hdr := buffer.NewPrependable(header.IPv4MinimumSize) // Serialize IPv4 fixed header. 
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)), ID: frag.id, Flags: frag.flags, FragmentOffset: frag.fragmentOffset, TTL: 64, Protocol: uint8(header.UDPProtocolNumber), SrcAddr: frag.srcAddr, DstAddr: frag.dstAddr, }) ip.SetChecksum(^ip.CalculateChecksum()) vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) } for i, expectedPayload := range test.expectedPayloads { // Check UDP payload delivered by UDP endpoint. var buf bytes.Buffer result, err := ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Fatalf("(i=%d) ep.Read: %s", i, err) } if diff := cmp.Diff(tcpip.ReadResult{ Count: len(expectedPayload), Total: len(expectedPayload), }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff) } if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" { t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff) } // Check IPv4 header in packet delivered by raw endpoint. buf.Reset() result, err = epRaw.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Fatalf("(i=%d) epRaw.Read: %s", i, err) } // Reassambly does not take care of checksum. Here we write our own // check routine instead of using checker.IPv4. 
ip := header.IPv4(buf.Bytes())
				for _, check := range []checker.NetworkChecker{
					checker.FragmentFlags(0),
					checker.FragmentOffset(0),
					checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))),
				} {
					check(t, []header.Network{ip})
				}
			}

			res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
			if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
				t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
			}
		})
	}
}

// TestWriteStats sends nPackets packets through a route, using either
// Route.WritePacket or Route.WritePackets. It checks the IP PacketsSent and
// IPTablesOutputDropped stats and the number of packets reported as written.
// Cases vary the link endpoint's packet allowance and the iptables OUTPUT
// rules.
func TestWriteStats(t *testing.T) {
	const nPackets = 3

	tests := []struct {
		name          string
		setup         func(*testing.T, *stack.Stack)
		allowPackets  int
		expectSent    int
		expectDropped int
		expectWritten int
	}{
		{
			name: "Accept all",
			// No setup needed, tables accept everything by default.
			setup:         func(*testing.T, *stack.Stack) {},
			allowPackets:  math.MaxInt32,
			expectSent:    nPackets,
			expectDropped: 0,
			expectWritten: nPackets,
		}, {
			name: "Accept all with error",
			// No setup needed, tables accept everything by default.
			setup:         func(*testing.T, *stack.Stack) {},
			allowPackets:  nPackets - 1,
			expectSent:    nPackets - 1,
			expectDropped: 0,
			expectWritten: nPackets - 1,
		}, {
			name: "Drop all",
			setup: func(t *testing.T, stk *stack.Stack) {
				// Install Output DROP rule.
				t.Helper()
				ipt := stk.IPTables()
				filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
				ruleIdx := filter.BuiltinChains[stack.Output]
				filter.Rules[ruleIdx].Target = &stack.DropTarget{}
				if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
					t.Fatalf("failed to replace table: %s", err)
				}
			},
			allowPackets:  math.MaxInt32,
			expectSent:    0,
			expectDropped: nPackets,
			expectWritten: nPackets,
		}, {
			name: "Drop some",
			setup: func(t *testing.T, stk *stack.Stack) {
				// Install Output DROP rule that matches only 1
				// of the 3 packets.
				t.Helper()
				ipt := stk.IPTables()
				filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
				// We'll match and DROP the last packet.
				ruleIdx := filter.BuiltinChains[stack.Output]
				filter.Rules[ruleIdx].Target = &stack.DropTarget{}
				filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
				// Make sure the next rule is ACCEPT.
				filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
				if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
					t.Fatalf("failed to replace table: %s", err)
				}
			},
			allowPackets:  math.MaxInt32,
			expectSent:    nPackets - 1,
			expectDropped: 1,
			expectWritten: nPackets,
		},
	}

	// Parameterize the tests to run with both WritePacket and WritePackets.
	writers := []struct {
		name         string
		writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
	}{
		{
			name: "WritePacket",
			writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
				// Write one packet at a time, stopping at the first error so the
				// count mirrors what WritePackets reports.
				nWritten := 0
				for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
					if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
						return nWritten, err
					}
					nWritten++
				}
				return nWritten, nil
			},
		}, {
			name: "WritePackets",
			writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
				return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
			},
		},
	}

	for _, writer := range writers {
		t.Run(writer.name, func(t *testing.T) {
			for _, test := range tests {
				t.Run(test.name, func(t *testing.T) {
					// NOTE(review): the mock endpoint presumably accepts
					// test.allowPackets writes and then fails with the given
					// error; confirm against testutil.NewMockLinkEndpoint.
					ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
					rt := buildRoute(t, ep)
					// buildRoute hands ownership of the route to us; release it
					// once the subtest is done.
					defer rt.Release()

					var pkts stack.PacketBufferList
					for i := 0; i < nPackets; i++ {
						pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
							ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
							Data:               buffer.NewView(0).ToVectorisedView(),
						})
						pkt.TransportHeader().Push(header.UDPMinimumSize)
						pkts.PushBack(pkt)
					}

					test.setup(t, rt.Stack())

					// The error is deliberately ignored. "Accept all with error"
					// expects writes to fail part-way, and the outcome is verified
					// through nWritten and the stats below.
					nWritten, _ := writer.writePackets(rt, pkts)

					if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
						t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
					}
					if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
						t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
					}
					if nWritten != test.expectWritten {
						t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
					}
				})
			}
		})
	}
}

// buildRoute creates a fresh IPv4-only stack with a single NIC (ID 1) backed by
// ep. It assigns 16.0.0.1 to that NIC, installs a /32 route to 16.0.0.2, and
// returns the route between the two addresses. The route is not released
// here; the caller is responsible for releasing it.
func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
	t.Helper()

	s := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
	})
	if err := s.CreateNIC(1, ep); err != nil {
		t.Fatalf("CreateNIC(1, _) failed: %s", err)
	}
	// Typed as tcpip.Address so the %s verbs below print dotted-quad notation
	// via Address.String rather than the raw address bytes.
	const (
		src tcpip.Address = "\x10\x00\x00\x01"
		dst tcpip.Address = "\x10\x00\x00\x02"
	)
	if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
		t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
	}
	{
		// An all-ones mask makes the route cover exactly dst.
		mask := tcpip.AddressMask(header.IPv4Broadcast)
		subnet, err := tcpip.NewSubnet(dst, mask)
		if err != nil {
			t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
		}
		s.SetRouteTable([]tcpip.Route{{
			Destination: subnet,
			NIC:         1,
		}})
	}
	rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
	if err != nil {
		t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err)
	}
	return rt
}

// limitedMatcher is an iptables matcher that matches after a certain number of
// packets are checked against it. The first limit packets do not match, and
// every packet after them does.
type limitedMatcher struct {
	// limit is the number of remaining packets that will not match.
	limit int
}

// Name implements Matcher.Name.
func (*limitedMatcher) Name() string {
	return "limitedMatcher"
}

// Match implements Matcher.Match.
func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { if lm.limit == 0 { return true, false } lm.limit-- return false, false } func TestPacketQueing(t *testing.T) { const nicID = 1 var ( host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") host1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } host2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 8, }, } ) tests := []struct { name string rxPkt func(*channel.Endpoint) checkResp func(*testing.T, *channel.Endpoint) }{ { name: "ICMP Error", rxPkt: func(e *channel.Endpoint) { hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize) u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) u.Encode(&header.UDPFields{ SrcPort: 5555, DstPort: 80, Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) sum = header.Checksum(header.UDP([]byte{}), sum) u.SetChecksum(^u.CalculateChecksum(sum)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, TTL: ipv4.DefaultTTL, Protocol: uint8(udp.ProtocolNumber), SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, DstAddr: host1IPv4Addr.AddressWithPrefix.Address, }) ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { p, ok := e.ReadContext(context.Background()) if !ok { t.Fatalf("timed out waiting for 
packet") } if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } if p.Route.RemoteLinkAddress != host2NICLinkAddr { t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), checker.ICMPv4( checker.ICMPv4Type(header.ICMPv4DstUnreachable), checker.ICMPv4Code(header.ICMPv4PortUnreachable))) }, }, { name: "Ping", rxPkt: func(e *channel.Endpoint) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) pkt.SetType(header.ICMPv4Echo) pkt.SetCode(0) pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, 0)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ TotalLength: uint16(totalLen), Protocol: uint8(icmp.ProtocolNumber4), TTL: ipv4.DefaultTTL, SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, DstAddr: host1IPv4Addr.AddressWithPrefix.Address, }) ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { p, ok := e.ReadContext(context.Background()) if !ok { t.Fatalf("timed out waiting for packet") } if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } if p.Route.RemoteLinkAddress != host2NICLinkAddr { t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), checker.ICMPv4( 
checker.ICMPv4Type(header.ICMPv4EchoReply), checker.ICMPv4Code(header.ICMPv4UnusedCode))) }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { e := channel.New(1, defaultMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) } s.SetRouteTable([]tcpip.Route{ { Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), NIC: nicID, }, }) // Receive a packet to trigger link resolution before a response is sent. test.rxPkt(e) // Wait for a ARP request since link address resolution should be // performed. { p, ok := e.ReadContext(context.Background()) if !ok { t.Fatalf("timed out waiting for packet") } if p.Proto != arp.ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) } if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) } rep := header.ARP(p.Pkt.NetworkHeader().View()) if got := rep.Op(); got != header.ARPRequest { t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) } if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr { t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr) } if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address { t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address) } if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address { 
t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address) } } // Send an ARP reply to complete link address resolution. { hdr := buffer.View(make([]byte, header.ARPSize)) packet := header.ARP(hdr) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) copy(packet.HardwareAddressSender(), host2NICLinkAddr) copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address) copy(packet.HardwareAddressTarget(), host1NICLinkAddr) copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address) e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.ToVectorisedView(), })) } // Expect the response now that the link address has resolved. test.checkResp(t, e) // Since link resolution was already performed, it shouldn't be performed // again. test.rxPkt(e) test.checkResp(t, e) }) } } // TestCloseLocking test that lock ordering is followed when closing an // endpoint. func TestCloseLocking(t *testing.T) { const ( nicID1 = 1 nicID2 = 2 src = tcpip.Address("\x10\x00\x00\x01") dst = tcpip.Address("\x10\x00\x00\x02") iterations = 1000 ) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) // Perform NAT so that the endoint tries to search for a sibling endpoint // which ends up taking the protocol and endpoint lock (in that order). 
table := stack.Table{ Rules: []stack.Rule{ {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}}, {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [stack.NumHooks]int{ stack.Prerouting: 0, stack.Input: 1, stack.Forward: stack.HookUnset, stack.Output: 2, stack.Postrouting: 3, }, Underflows: [stack.NumHooks]int{ stack.Prerouting: 0, stack.Input: 1, stack.Forward: stack.HookUnset, stack.Output: 2, stack.Postrouting: 3, }, } if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil { t.Fatalf("s.IPTables().ReplaceTable(...): %s", err) } e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil { t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: nicID1, }}) var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { t.Fatal(err) } defer ep.Close() addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53} if err := ep.Connect(addr); err != nil { t.Errorf("ep.Connect(%#v): %s", addr, err) } var wg sync.WaitGroup defer wg.Wait() // Writing packets should trigger NAT which requires the stack to search the // protocol for network endpoints with the destination address. // // Creating and removing interfaces should modify the protocol and endpoint // which requires taking the locks of each. // // We expect the protocol > endpoint lock ordering to be followed here. 
wg.Add(2) go func() { defer wg.Done() data := []byte{1, 2, 3, 4} for i := 0; i < iterations; i++ { var r bytes.Reader r.Reset(data) if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Errorf("ep.Write(_, _): %s", err) return } else if want := int64(len(data)); n != want { t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) return } } }() go func() { defer wg.Done() for i := 0; i < iterations; i++ { if err := s.CreateNIC(nicID2, loopback.New()); err != nil { t.Errorf("CreateNIC(%d, _): %s", nicID2, err) return } if err := s.RemoveNIC(nicID2); err != nil { t.Errorf("RemoveNIC(%d): %s", nicID2, err) return } } }() }