summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD22
-rw-r--r--pkg/tcpip/network/arp/BUILD32
-rw-r--r--pkg/tcpip/network/arp/arp.go217
-rw-r--r--pkg/tcpip/network/arp/arp_test.go146
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD45
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go77
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap_test.go126
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go144
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go165
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go118
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go105
-rw-r--r--pkg/tcpip/network/hash/BUILD13
-rw-r--r--pkg/tcpip/network/hash/hash.go93
-rw-r--r--pkg/tcpip/network/ip_test.go655
-rw-r--r--pkg/tcpip/network/ipv4/BUILD38
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go160
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go598
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go475
-rw-r--r--pkg/tcpip/network/ipv6/BUILD44
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go546
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go960
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go521
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go1265
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go906
24 files changed, 7471 insertions, 0 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
new file mode 100644
index 000000000..6a4839fb8
--- /dev/null
+++ b/pkg/tcpip/network/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+go_test(
+ name = "ip_test",
+ size = "small",
+ srcs = [
+ "ip_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ ],
+)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
new file mode 100644
index 000000000..eddf7b725
--- /dev/null
+++ b/pkg/tcpip/network/arp/BUILD
@@ -0,0 +1,32 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "arp",
+ srcs = ["arp.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "arp_test",
+ size = "small",
+ srcs = ["arp_test.go"],
+ deps = [
+ ":arp",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
+ ],
+)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
new file mode 100644
index 000000000..9d0797af7
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp.go
@@ -0,0 +1,217 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package arp implements the ARP network protocol. It is used to resolve
+// IPv4 addresses into link-local MAC addresses, and advertises IPv4
+// addresses of its stack with the local network.
+//
+// To use it in the networking stack, pass arp.NewProtocol() as one of the
+// network protocols when calling stack.New. Then add an "arp" address to every
+// NIC on the stack that should respond to ARP requests. That is:
+//
+// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
+// // handle err
+// }
+package arp
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolNumber is the ARP protocol number.
+ ProtocolNumber = header.ARPProtocolNumber
+
+ // ProtocolAddress is the address expected by the ARP endpoint.
+ ProtocolAddress = tcpip.Address("arp")
+)
+
+// endpoint implements stack.NetworkEndpoint.
+type endpoint struct {
+ protocol *protocol
+ nicID tcpip.NICID
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache
+}
+
+// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
+func (e *endpoint) MTU() uint32 {
+ lmtu := e.linkEP.MTU()
+ return lmtu - uint32(e.MaxHeaderLength())
+}
+
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &stack.NetworkEndpointID{ProtocolAddress}
+}
+
+func (e *endpoint) PrefixLen() int {
+ return 0
+}
+
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.ARPSize
+}
+
+func (e *endpoint) Close() {}
+
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return e.protocol.Number()
+}
+
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
+ v, ok := pkt.Data.PullUp(header.ARPSize)
+ if !ok {
+ return
+ }
+ h := header.ARP(v)
+ if !h.IsValid() {
+ return
+ }
+
+ switch h.Op() {
+ case header.ARPRequest:
+ localAddr := tcpip.Address(h.ProtocolAddressTarget())
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+ hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
+ packet := header.ARP(hdr.Prepend(header.ARPSize))
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
+ copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
+ copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
+ e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
+ Header: hdr,
+ })
+ fallthrough // also fill the cache from requests
+ case header.ARPReply:
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ }
+}
+
+// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
+type protocol struct {
+}
+
+func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
+func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
+func (p *protocol) DefaultPrefixLen() int { return 0 }
+
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.ARP(v)
+ return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
+}
+
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+ if addrWithPrefix.Address != ProtocolAddress {
+ return nil, tcpip.ErrBadLocalAddress
+ }
+ return &endpoint{
+ protocol: p,
+ nicID: nicID,
+ linkEP: sender,
+ linkAddrCache: linkAddrCache,
+ }, nil
+}
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv4ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ r := &stack.Route{
+ RemoteLinkAddress: broadcastMAC,
+ }
+
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
+ h := header.ARP(hdr.Prepend(header.ARPSize))
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), linkEP.LinkAddress())
+ copy(h.ProtocolAddressSender(), localAddr)
+ copy(h.ProtocolAddressTarget(), addr)
+
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
+ Header: hdr,
+ })
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == header.IPv4Broadcast {
+ return broadcastMAC, true
+ }
+ if header.IsV4MulticastAddress(addr) {
+ return header.EthernetAddressFromMulticastIPv4Address(addr), true
+ }
+ return tcpip.LinkAddress([]byte(nil)), false
+}
+
+// SetOption implements stack.NetworkProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements stack.NetworkProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+// NewProtocol returns an ARP network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
+}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
new file mode 100644
index 000000000..1646d9cde
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -0,0 +1,146 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arp_test
+
+import (
+ "context"
+ "strconv"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+)
+
+const (
+ stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+ stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
+ stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
+ stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+)
+
+type testContext struct {
+ t *testing.T
+ linkEP *channel.Endpoint
+ s *stack.Stack
+}
+
+func newTestContext(t *testing.T) *testContext {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ })
+
+ const defaultMTU = 65536
+ ep := channel.New(256, defaultMTU, stackLinkAddr)
+ wep := stack.LinkEndpoint(ep)
+
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+ if err := s.CreateNIC(1, wep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
+ if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("AddAddress for arp failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ }})
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEP: ep,
+ }
+}
+
+func (c *testContext) cleanup() {
+ c.linkEP.Close()
+}
+
+func TestDirectRequest(t *testing.T) {
+ c := newTestContext(t)
+ defer c.cleanup()
+
+ const senderMAC = "\x01\x02\x03\x04\x05\x06"
+ const senderIPv4 = "\x0a\x00\x00\x02"
+
+ v := make(buffer.View, header.ARPSize)
+ h := header.ARP(v)
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), senderMAC)
+ copy(h.ProtocolAddressSender(), senderIPv4)
+
+ inject := func(addr tcpip.Address) {
+ copy(h.ProtocolAddressTarget(), addr)
+ c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{
+ Data: v.ToVectorisedView(),
+ })
+ }
+
+ for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ inject(address)
+ pi, _ := c.linkEP.ReadContext(context.Background())
+ if pi.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
+ }
+ rep := header.ARP(pi.Pkt.Header.View())
+ if !rep.IsValid() {
+ t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
+ t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want)
+ }
+ })
+ }
+
+ inject(stackAddrBad)
+ // Sleep tests are gross, but this will only potentially flake
+ // if there's a bug. If there is no bug this will reliably
+ // succeed.
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+ if pkt, ok := c.linkEP.ReadContext(ctx); ok {
+ t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
new file mode 100644
index 000000000..d1c728ccf
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -0,0 +1,45 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "reassembler_list",
+ out = "reassembler_list.go",
+ package = "fragmentation",
+ prefix = "reassembler",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*reassembler",
+ "Linker": "*reassembler",
+ },
+)
+
+go_library(
+ name = "fragmentation",
+ srcs = [
+ "frag_heap.go",
+ "fragmentation.go",
+ "reassembler.go",
+ "reassembler_list.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ ],
+)
+
+go_test(
+ name = "fragmentation_test",
+ size = "small",
+ srcs = [
+ "frag_heap_test.go",
+ "fragmentation_test.go",
+ "reassembler_test.go",
+ ],
+ library = ":fragmentation",
+ deps = ["//pkg/tcpip/buffer"],
+)
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
new file mode 100644
index 000000000..0b570d25a
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/frag_heap.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+type fragment struct {
+ offset uint16
+ vv buffer.VectorisedView
+}
+
+type fragHeap []fragment
+
+func (h *fragHeap) Len() int {
+ return len(*h)
+}
+
+func (h *fragHeap) Less(i, j int) bool {
+ return (*h)[i].offset < (*h)[j].offset
+}
+
+func (h *fragHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
+}
+
+func (h *fragHeap) Push(x interface{}) {
+ *h = append(*h, x.(fragment))
+}
+
+func (h *fragHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[:n-1]
+ return x
+}
+
+// reassamble empties the heap and returns a VectorisedView
+// containing a reassambled version of the fragments inside the heap.
+func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
+ curr := heap.Pop(h).(fragment)
+ views := curr.vv.Views()
+ size := curr.vv.Size()
+
+ if curr.offset != 0 {
+ return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
+ }
+
+ for h.Len() > 0 {
+ curr := heap.Pop(h).(fragment)
+ if int(curr.offset) < size {
+ curr.vv.TrimFront(size - int(curr.offset))
+ } else if int(curr.offset) > size {
+ return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
+ }
+ size += curr.vv.Size()
+ views = append(views, curr.vv.Views()...)
+ }
+ return buffer.NewVectorisedView(size, views), nil
+}
diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go
new file mode 100644
index 000000000..9ececcb9f
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/frag_heap_test.go
@@ -0,0 +1,126 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+var reassambleTestCases = []struct {
+ comment string
+ in []fragment
+ want buffer.VectorisedView
+}{
+ {
+ comment: "Non-overlapping in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(1, "0")},
+ {offset: 1, vv: vv(1, "1")},
+ },
+ want: vv(2, "0", "1"),
+ },
+ {
+ comment: "Non-overlapping out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(1, "0")},
+ },
+ want: vv(2, "0", "1"),
+ },
+ {
+ comment: "Duplicated packets",
+ in: []fragment{
+ {offset: 0, vv: vv(1, "0")},
+ {offset: 0, vv: vv(1, "0")},
+ },
+ want: vv(1, "0"),
+ },
+ {
+ comment: "Overlapping in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(2, "01")},
+ {offset: 1, vv: vv(2, "12")},
+ },
+ want: vv(3, "01", "2"),
+ },
+ {
+ comment: "Overlapping out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(2, "12")},
+ {offset: 0, vv: vv(2, "01")},
+ },
+ want: vv(3, "01", "2"),
+ },
+ {
+ comment: "Overlapping subset in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(3, "012")},
+ {offset: 1, vv: vv(1, "1")},
+ },
+ want: vv(3, "012"),
+ },
+ {
+ comment: "Overlapping subset out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(3, "012")},
+ },
+ want: vv(3, "012"),
+ },
+}
+
+func TestReassamble(t *testing.T) {
+ for _, c := range reassambleTestCases {
+ t.Run(c.comment, func(t *testing.T) {
+ h := make(fragHeap, 0, 8)
+ heap.Init(&h)
+ for _, f := range c.in {
+ heap.Push(&h, f)
+ }
+ got, err := h.reassemble()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, c.want) {
+ t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want)
+ }
+ })
+ }
+}
+
+func TestReassambleFailsForNonZeroOffset(t *testing.T) {
+ h := make(fragHeap, 0, 8)
+ heap.Init(&h)
+ heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
+ _, err := h.reassemble()
+ if err == nil {
+ t.Errorf("reassemble() did not fail when the first packet had offset != 0")
+ }
+}
+
+func TestReassambleFailsForHoles(t *testing.T) {
+ h := make(fragHeap, 0, 8)
+ heap.Init(&h)
+ heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
+ heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
+ _, err := h.reassemble()
+ if err == nil {
+ t.Errorf("reassemble() did not fail when there was a hole in the packet")
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
new file mode 100644
index 000000000..f42abc4bb
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -0,0 +1,144 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fragmentation contains the implementation of IP fragmentation.
+// It is based on RFC 791 and RFC 815.
+package fragmentation
+
+import (
+ "fmt"
+ "log"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
+const DefaultReassembleTimeout = 30 * time.Second
+
+// HighFragThreshold is the threshold at which we start trimming old
+// fragmented packets. Linux uses a default value of 4 MB. See
+// net.ipv4.ipfrag_high_thresh for more information.
+const HighFragThreshold = 4 << 20 // 4MB
+
+// LowFragThreshold is the threshold we reach to when we start dropping
+// older fragmented packets. It's important that we keep enough room for newer
+// packets to be re-assembled. Hence, this needs to be lower than
+// HighFragThreshold enough. Linux uses a default value of 3 MB. See
+// net.ipv4.ipfrag_low_thresh for more information.
+const LowFragThreshold = 3 << 20 // 3MB
+
+// Fragmentation is the main structure that other modules
+// of the stack should use to implement IP Fragmentation.
+type Fragmentation struct {
+ mu sync.Mutex
+ highLimit int
+ lowLimit int
+ reassemblers map[uint32]*reassembler
+ rList reassemblerList
+ size int
+ timeout time.Duration
+}
+
+// NewFragmentation creates a new Fragmentation.
+//
+// highMemoryLimit specifies the limit on the memory consumed
+// by the fragments stored by Fragmentation (overhead of internal data-structures
+// is not accounted). Fragments are dropped when the limit is reached.
+//
+// lowMemoryLimit specifies the limit on which we will reach by dropping
+// fragments after reaching highMemoryLimit.
+//
+// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
+// Fragments are lazily evicted only when a new a packet with an
+// already existing fragmentation-id arrives after the timeout.
+func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+ if lowMemoryLimit >= highMemoryLimit {
+ lowMemoryLimit = highMemoryLimit
+ }
+
+ if lowMemoryLimit < 0 {
+ lowMemoryLimit = 0
+ }
+
+ return &Fragmentation{
+ reassemblers: make(map[uint32]*reassembler),
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
+ timeout: reassemblingTimeout,
+ }
+}
+
+// Process processes an incoming fragment belonging to an ID
+// and returns a complete packet when all the packets belonging to that ID have been received.
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
+ f.mu.Lock()
+ r, ok := f.reassemblers[id]
+ if ok && r.tooOld(f.timeout) {
+ // This is very likely to be an id-collision or someone performing a slow-rate attack.
+ f.release(r)
+ ok = false
+ }
+ if !ok {
+ r = newReassembler(id)
+ f.reassemblers[id] = r
+ f.rList.PushFront(r)
+ }
+ f.mu.Unlock()
+
+ res, done, consumed, err := r.process(first, last, more, vv)
+ if err != nil {
+ // We probably got an invalid sequence of fragments. Just
+ // discard the reassembler and move on.
+ f.mu.Lock()
+ f.release(r)
+ f.mu.Unlock()
+ return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err)
+ }
+ f.mu.Lock()
+ f.size += consumed
+ if done {
+ f.release(r)
+ }
+ // Evict reassemblers if we are consuming more memory than highLimit until
+ // we reach lowLimit.
+ if f.size > f.highLimit {
+ for f.size > f.lowLimit {
+ tail := f.rList.Back()
+ if tail == nil {
+ break
+ }
+ f.release(tail)
+ }
+ }
+ f.mu.Unlock()
+ return res, done, nil
+}
+
+func (f *Fragmentation) release(r *reassembler) {
+ // Before releasing a fragment we need to check if r is already marked as done.
+ // Otherwise, we would delete it twice.
+ if r.checkDoneOrMark() {
+ return
+ }
+
+ delete(f.reassemblers, r.id)
+ f.rList.Remove(r)
+ f.size -= r.size
+ if f.size < 0 {
+ log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
+ f.size = 0
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
new file mode 100644
index 000000000..72c0f53be
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -0,0 +1,165 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "reflect"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// vv is a helper to build VectorisedView from different strings.
+func vv(size int, pieces ...string) buffer.VectorisedView {
+ views := make([]buffer.View, len(pieces))
+ for i, p := range pieces {
+ views[i] = []byte(p)
+ }
+
+ return buffer.NewVectorisedView(size, views)
+}
+
+type processInput struct {
+ id uint32
+ first uint16
+ last uint16
+ more bool
+ vv buffer.VectorisedView
+}
+
+type processOutput struct {
+ vv buffer.VectorisedView
+ done bool
+}
+
+var processTestCases = []struct {
+ comment string
+ in []processInput
+ out []processOutput
+}{
+ {
+ comment: "One ID",
+ in: []processInput{
+ {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: vv(4, "01", "23"), done: true},
+ },
+ },
+ {
+ comment: "Two IDs",
+ in: []processInput{
+ {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
+ {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
+ {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: vv(4, "ab", "cd"), done: true},
+ {vv: vv(4, "01", "23"), done: true},
+ },
+ },
+}
+
+func TestFragmentationProcess(t *testing.T) {
+ for _, c := range processTestCases {
+ t.Run(c.comment, func(t *testing.T) {
+ f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ for i, in := range c.in {
+ vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ if err != nil {
+ t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err)
+ }
+ if !reflect.DeepEqual(vv, c.out[i].vv) {
+ t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
+ }
+ if done != c.out[i].done {
+ t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done)
+ }
+ if c.out[i].done {
+ if _, ok := f.reassemblers[in.id]; ok {
+ t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
+ }
+ for n := f.rList.Front(); n != nil; n = n.Next() {
+ if n.id == in.id {
+ t.Errorf("Process(%d) did not remove buffer from rList", i)
+ }
+ }
+ }
+ }
+ })
+ }
+}
+
+func TestReassemblingTimeout(t *testing.T) {
+ timeout := time.Millisecond
+ f := NewFragmentation(1024, 512, timeout)
+ // Send first fragment with id = 0, first = 0, last = 0, and more = true.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Sleep more than the timeout.
+ time.Sleep(2 * timeout)
+ // Send another fragment that completes a packet.
+ // However, no packet should be reassembled because the fragment arrived after the timeout.
+ _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
+ if err != nil {
+ t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
+ }
+ if done {
+ t.Errorf("Fragmentation does not respect the reassembling timeout.")
+ }
+}
+
+func TestMemoryLimits(t *testing.T) {
+ f := NewFragmentation(3, 1, DefaultReassembleTimeout)
+ // Send first fragment with id = 0.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Send first fragment with id = 1.
+ f.Process(1, 0, 0, true, vv(1, "1"))
+ // Send first fragment with id = 2.
+ f.Process(2, 0, 0, true, vv(1, "2"))
+
+ // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
+ // evicted.
+ f.Process(3, 0, 0, true, vv(1, "3"))
+
+ if _, ok := f.reassemblers[0]; ok {
+ t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[1]; ok {
+ t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[3]; !ok {
+ t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
+ }
+}
+
+func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
+ f := NewFragmentation(1, 0, DefaultReassembleTimeout)
+ // Send first fragment with id = 0.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Send the same packet again.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+
+ got := f.size
+ want := 1
+ if got != want {
+ t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
new file mode 100644
index 000000000..0a83d81f2
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+type hole struct {
+ first uint16
+ last uint16
+ deleted bool
+}
+
+type reassembler struct {
+ reassemblerEntry
+ id uint32
+ size int
+ mu sync.Mutex
+ holes []hole
+ deleted int
+ heap fragHeap
+ done bool
+ creationTime time.Time
+}
+
+func newReassembler(id uint32) *reassembler {
+ r := &reassembler{
+ id: id,
+ holes: make([]hole, 0, 16),
+ deleted: 0,
+ heap: make(fragHeap, 0, 8),
+ creationTime: time.Now(),
+ }
+ r.holes = append(r.holes, hole{
+ first: 0,
+ last: math.MaxUint16,
+ deleted: false})
+ return r
+}
+
+// updateHoles updates the list of holes for an incoming fragment and
+// returns true iff the fragment filled at least part of an existing hole.
+func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
+ used := false
+ for i := range r.holes {
+ if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
+ continue
+ }
+ used = true
+ r.deleted++
+ r.holes[i].deleted = true
+ if first > r.holes[i].first {
+ r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
+ }
+ if last < r.holes[i].last && more {
+ r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
+ }
+ }
+ return used
+}
+
+func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ consumed := 0
+ if r.done {
+ // A concurrent goroutine might have already reassembled
+ // the packet and emptied the heap while this goroutine
+ // was waiting on the mutex. We don't have to do anything in this case.
+ return buffer.VectorisedView{}, false, consumed, nil
+ }
+ if r.updateHoles(first, last, more) {
+ // We store the incoming packet only if it filled some holes.
+ heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
+ consumed = vv.Size()
+ r.size += consumed
+ }
+ // Check if all the holes have been deleted and we are ready to reassamble.
+ if r.deleted < len(r.holes) {
+ return buffer.VectorisedView{}, false, consumed, nil
+ }
+ res, err := r.heap.reassemble()
+ if err != nil {
+ return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err)
+ }
+ return res, true, consumed, nil
+}
+
+func (r *reassembler) tooOld(timeout time.Duration) bool {
+ return time.Now().Sub(r.creationTime) > timeout
+}
+
+func (r *reassembler) checkDoneOrMark() bool {
+ r.mu.Lock()
+ prev := r.done
+ r.done = true
+ r.mu.Unlock()
+ return prev
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
new file mode 100644
index 000000000..7eee0710d
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "math"
+ "reflect"
+ "testing"
+)
+
+type updateHolesInput struct {
+ first uint16
+ last uint16
+ more bool
+}
+
+var holesTestCases = []struct {
+ comment string
+ in []updateHolesInput
+ want []hole
+}{
+ {
+ comment: "No fragments. Expected holes: {[0 -> inf]}.",
+ in: []updateHolesInput{},
+ want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
+ },
+ {
+ comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
+ in: []updateHolesInput{{first: 0, last: 1, more: true}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 2, last: math.MaxUint16, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
+ in: []updateHolesInput{{first: 1, last: 2, more: true}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: 0, deleted: false},
+ {first: 3, last: math.MaxUint16, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment at the end. Expected holes: {[0, 0]}.",
+ in: []updateHolesInput{{first: 1, last: 2, more: false}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: 0, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{{first: 0, last: 1, more: false}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ },
+ },
+ {
+ comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{
+ {first: 0, last: 1, more: true},
+ {first: 2, last: 3, more: false},
+ },
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 2, last: math.MaxUint16, deleted: true},
+ },
+ },
+ {
+ comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{
+ {first: 0, last: 2, more: true},
+ {first: 2, last: 3, more: false},
+ },
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 3, last: math.MaxUint16, deleted: true},
+ },
+ },
+}
+
+func TestUpdateHoles(t *testing.T) {
+ for _, c := range holesTestCases {
+ r := newReassembler(0)
+ for _, i := range c.in {
+ r.updateHoles(i.first, i.last, i.more)
+ }
+ if !reflect.DeepEqual(r.holes, c.want) {
+ t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
new file mode 100644
index 000000000..872165866
--- /dev/null
+++ b/pkg/tcpip/network/hash/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "hash",
+ srcs = ["hash.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/rand",
+ "//pkg/tcpip/header",
+ ],
+)
diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go
new file mode 100644
index 000000000..8f65713c5
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hash contains utility functions for hashing.
+package hash
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+var hashIV = RandN32(1)[0]
+
+// RandN32 generates a slice of n cryptographic random 32-bit numbers.
+func RandN32(n int) []uint32 {
+ b := make([]byte, 4*n)
+ if _, err := rand.Read(b); err != nil {
+ panic("unable to get random numbers: " + err.Error())
+ }
+ r := make([]uint32, n)
+ for i := range r {
+ r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)])
+ }
+ return r
+}
+
+// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted
+// from linux.
+func Hash3Words(a, b, c, initval uint32) uint32 {
+ const iv = 0xdeadbeef + (3 << 2)
+ initval += iv
+
+ a += initval
+ b += initval
+ c += initval
+
+ c ^= b
+ c -= rol32(b, 14)
+ a ^= c
+ a -= rol32(c, 11)
+ b ^= a
+ b -= rol32(a, 25)
+ c ^= b
+ c -= rol32(b, 16)
+ a ^= c
+ a -= rol32(c, 4)
+ b ^= a
+ b -= rol32(a, 14)
+ c ^= b
+ c -= rol32(b, 24)
+
+ return c
+}
+
+// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791.
+func IPv4FragmentHash(h header.IPv4) uint32 {
+ x := uint32(h.ID())<<16 | uint32(h.Protocol())
+ t := h.SourceAddress()
+ y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = h.DestinationAddress()
+ z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return Hash3Words(x, y, z, hashIV)
+}
+
+// IPv6FragmentHash computes the hash of the ipv6 fragment.
+// Unlike IPv4, the protocol is not used to compute the hash.
+// RFC 2640 (sec 4.5) is not very sharp on this aspect.
+// As a reference, also Linux ignores the protocol to compute
+// the hash (inet6_hash_frag).
+func IPv6FragmentHash(h header.IPv6, id uint32) uint32 {
+ t := h.SourceAddress()
+ y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = h.DestinationAddress()
+ z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return Hash3Words(id, y, z, hashIV)
+}
+
+func rol32(v, shift uint32) uint32 {
+ return (v << shift) | (v >> ((-shift) & 31))
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
new file mode 100644
index 000000000..4c20301c6
--- /dev/null
+++ b/pkg/tcpip/network/ip_test.go
@@ -0,0 +1,655 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+const (
+ localIpv4Addr = "\x0a\x00\x00\x01"
+ localIpv4PrefixLen = 24
+ remoteIpv4Addr = "\x0a\x00\x00\x02"
+ ipv4SubnetAddr = "\x0a\x00\x00\x00"
+ ipv4SubnetMask = "\xff\xff\xff\x00"
+ ipv4Gateway = "\x0a\x00\x00\x03"
+ localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ localIpv6PrefixLen = 120
+ remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
+ ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+)
+
+// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
+// The former is used to pretend that it's a link endpoint so that we can
+// inspect packets written by the network endpoints. The latter is used to
+// pretend that it's the network stack so that it can inspect incoming packets
+// that have been handled by the network endpoints.
+//
+// Packets are checked by comparing their fields/values against the expected
+// values stored in the test object itself.
+type testObject struct {
+ t *testing.T
+ protocol tcpip.TransportProtocolNumber
+ contents []byte
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ v4 bool
+ typ stack.ControlType
+ extra uint32
+
+ dataCalls int
+ controlCalls int
+}
+
+// checkValues verifies that the transport protocol, data contents, src & dst
+// addresses of a packet match what's expected. If any field doesn't match, the
+// test fails.
+func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) {
+ v := vv.ToView()
+ if protocol != t.protocol {
+ t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
+ }
+
+ if srcAddr != t.srcAddr {
+ t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr)
+ }
+
+ if dstAddr != t.dstAddr {
+ t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr)
+ }
+
+ if len(v) != len(t.contents) {
+ t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents))
+ }
+
+ for i := range t.contents {
+ if t.contents[i] != v[i] {
+ t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i])
+ }
+ }
+}
+
+// DeliverTransportPacket is called by network endpoints after parsing incoming
+// packets. This is used by the test object to verify that the results of the
+// parsing are expected.
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) {
+ t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
+ t.dataCalls++
+}
+
+// DeliverTransportControlPacket is called by network endpoints after parsing
+// incoming control (ICMP) packets. This is used by the test object to verify
+// that the results of the parsing are expected.
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+ t.checkValues(trans, pkt.Data, remote, local)
+ if typ != t.typ {
+ t.t.Errorf("typ = %v, want %v", typ, t.typ)
+ }
+ if extra != t.extra {
+ t.t.Errorf("extra = %v, want %v", extra, t.extra)
+ }
+ t.controlCalls++
+}
+
+// Attach is only implemented to satisfy the LinkEndpoint interface.
+func (*testObject) Attach(stack.NetworkDispatcher) {}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (*testObject) IsAttached() bool {
+ return true
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
+// matches the linux loopback MTU.
+func (*testObject) MTU() uint32 {
+ return 65536
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
+// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
+func (*testObject) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*testObject) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*testObject) Wait() {}
+
+// WritePacket is called by network endpoints after producing a packet and
+// writing it to the link endpoint. This is used by the test object to verify
+// that the produced packet is as expected.
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+ var prot tcpip.TransportProtocolNumber
+ var srcAddr tcpip.Address
+ var dstAddr tcpip.Address
+
+ if t.v4 {
+ h := header.IPv4(pkt.Header.View())
+ prot = tcpip.TransportProtocolNumber(h.Protocol())
+ srcAddr = h.SourceAddress()
+ dstAddr = h.DestinationAddress()
+
+ } else {
+ h := header.IPv6(pkt.Header.View())
+ prot = tcpip.TransportProtocolNumber(h.NextHeader())
+ srcAddr = h.SourceAddress()
+ dstAddr = h.DestinationAddress()
+ }
+ t.checkValues(prot, pkt.Data, srcAddr, dstAddr)
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
+ s.CreateNIC(1, loopback.New())
+ s.AddAddress(1, ipv4.ProtocolNumber, local)
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ Gateway: ipv4Gateway,
+ NIC: 1,
+ }})
+
+ return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
+}
+
+func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
+ s.CreateNIC(1, loopback.New())
+ s.AddAddress(1, ipv6.ProtocolNumber, local)
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: ipv6Gateway,
+ NIC: 1,
+ }})
+
+ return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
+}
+
+func buildDummyStack() *stack.Stack {
+ return stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
+}
+
+func TestIPv4Send(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Allocate and initialize the payload view.
+ payload := buffer.NewView(100)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i)
+ }
+
+ // Allocate the header buffer.
+ hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+
+ // Issue the write.
+ o.protocol = 123
+ o.srcAddr = localIpv4Addr
+ o.dstAddr = remoteIpv4Addr
+ o.contents = payload
+
+ r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
+ }
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+}
+
+func TestIPv4Receive(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv4MinimumSize + 30
+ view := buffer.NewView(totalLen)
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ SrcAddr: remoteIpv4Addr,
+ DstAddr: localIpv4Addr,
+ })
+
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv4Addr
+ o.dstAddr = localIpv4Addr
+ o.contents = view[header.IPv4MinimumSize:totalLen]
+
+ r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
+ }
+ ep.HandlePacket(&r, stack.PacketBuffer{
+ Data: view.ToVectorisedView(),
+ })
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv4ReceiveControl(t *testing.T) {
+ const mtu = 0xbeef - header.IPv4MinimumSize
+ cases := []struct {
+ name string
+ expectedCount int
+ fragmentOffset uint16
+ code uint8
+ expectedTyp stack.ControlType
+ expectedExtra uint32
+ trunc int
+ }{
+ {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0},
+ {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10},
+ {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8},
+ {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8},
+ {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
+ }
+ r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb")
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
+ view := buffer.NewView(dataOffset + 8)
+
+ // Create the outer IPv4 header.
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(len(view) - c.trunc),
+ TTL: 20,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ SrcAddr: "\x0a\x00\x00\xbb",
+ DstAddr: localIpv4Addr,
+ })
+
+ // Create the ICMP header.
+ icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
+ icmp.SetType(header.ICMPv4DstUnreachable)
+ icmp.SetCode(c.code)
+ icmp.SetIdent(0xdead)
+ icmp.SetSequence(0xbeef)
+
+ // Create the inner IPv4 header.
+ ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: 100,
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: c.fragmentOffset,
+ SrcAddr: localIpv4Addr,
+ DstAddr: remoteIpv4Addr,
+ })
+
+ // Make payload be non-zero.
+ for i := dataOffset; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to IPv4 endpoint, dispatcher will validate that
+ // it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv4Addr
+ o.dstAddr = localIpv4Addr
+ o.contents = view[dataOffset:]
+ o.typ = c.expectedTyp
+ o.extra = c.expectedExtra
+
+ vv := view[:len(view)-c.trunc].ToVectorisedView()
+ ep.HandlePacket(&r, stack.PacketBuffer{
+ Data: vv,
+ })
+ if want := c.expectedCount; o.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ }
+ })
+ }
+}
+
+func TestIPv4FragmentationReceive(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv4MinimumSize + 24
+
+ frag1 := buffer.NewView(totalLen)
+ ip1 := header.IPv4(frag1)
+ ip1.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: 0,
+ Flags: header.IPv4FlagMoreFragments,
+ SrcAddr: remoteIpv4Addr,
+ DstAddr: localIpv4Addr,
+ })
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ frag1[i] = uint8(i)
+ }
+
+ frag2 := buffer.NewView(totalLen)
+ ip2 := header.IPv4(frag2)
+ ip2.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: 24,
+ SrcAddr: remoteIpv4Addr,
+ DstAddr: localIpv4Addr,
+ })
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ frag2[i] = uint8(i)
+ }
+
+ // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv4Addr
+ o.dstAddr = localIpv4Addr
+ o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+
+ r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
+ }
+
+ // Send first segment.
+ ep.HandlePacket(&r, stack.PacketBuffer{
+ Data: frag1.ToVectorisedView(),
+ })
+ if o.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ }
+
+ // Send second segment.
+ ep.HandlePacket(&r, stack.PacketBuffer{
+ Data: frag2.ToVectorisedView(),
+ })
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv6Send(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Allocate and initialize the payload view.
+ payload := buffer.NewView(100)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i)
+ }
+
+ // Allocate the header buffer.
+ hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+
+ // Issue the write.
+ o.protocol = 123
+ o.srcAddr = localIpv6Addr
+ o.dstAddr = remoteIpv6Addr
+ o.contents = payload
+
+ r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
+ }
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ }); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+}
+
+func TestIPv6Receive(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv6MinimumSize + 30
+ view := buffer.NewView(totalLen)
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: remoteIpv6Addr,
+ DstAddr: localIpv6Addr,
+ })
+
+ // Make payload be non-zero.
+ for i := header.IPv6MinimumSize; i < totalLen; i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv6Addr
+ o.dstAddr = localIpv6Addr
+ o.contents = view[header.IPv6MinimumSize:totalLen]
+
+ r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ if err != nil {
+ t.Fatalf("could not find route: %v", err)
+ }
+
+ ep.HandlePacket(&r, stack.PacketBuffer{
+ Data: view.ToVectorisedView(),
+ })
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv6ReceiveControl(t *testing.T) {
+ newUint16 := func(v uint16) *uint16 { return &v }
+
+ const mtu = 0xffff
+ const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa"
+ cases := []struct {
+ name string
+ expectedCount int
+ fragmentOffset *uint16
+ typ header.ICMPv6Type
+ code uint8
+ expectedTyp stack.ControlType
+ expectedExtra uint32
+ trunc int
+ }{
+ {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0},
+ {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10},
+ {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8},
+ {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8},
+ {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8},
+ {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
+ }
+ r, err := buildIPv6Route(
+ localIpv6Addr,
+ "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
+ )
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ defer ep.Close()
+
+ dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
+ if c.fragmentOffset != nil {
+ dataOffset += header.IPv6FragmentHeaderSize
+ }
+ view := buffer.NewView(dataOffset + 8)
+
+ // Create the outer IPv6 header.
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 20,
+ SrcAddr: outerSrcAddr,
+ DstAddr: localIpv6Addr,
+ })
+
+ // Create the ICMP header.
+ icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
+ icmp.SetType(c.typ)
+ icmp.SetCode(c.code)
+ icmp.SetIdent(0xdead)
+ icmp.SetSequence(0xbeef)
+
+ // Create the inner IPv6 header.
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: 100,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: localIpv6Addr,
+ DstAddr: remoteIpv6Addr,
+ })
+
+ // Build the fragmentation header if needed.
+ if c.fragmentOffset != nil {
+ ip.SetNextHeader(header.IPv6FragmentHeader)
+ frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
+ frag.Encode(&header.IPv6FragmentFields{
+ NextHeader: 10,
+ FragmentOffset: *c.fragmentOffset,
+ M: true,
+ Identification: 0x12345678,
+ })
+ }
+
+ // Make payload be non-zero.
+ for i := dataOffset; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to IPv6 endpoint, dispatcher will validate that
+ // it's ok.
+ o.protocol = 10
+ o.srcAddr = remoteIpv6Addr
+ o.dstAddr = localIpv6Addr
+ o.contents = view[dataOffset:]
+ o.typ = c.expectedTyp
+ o.extra = c.expectedExtra
+
+ // Set ICMPv6 checksum.
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
+
+ ep.HandlePacket(&r, stack.PacketBuffer{
+ Data: view[:len(view)-c.trunc].ToVectorisedView(),
+ })
+ if want := c.expectedCount; o.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
new file mode 100644
index 000000000..880ea7de2
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -0,0 +1,38 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ipv4",
+ srcs = [
+ "icmp.go",
+ "ipv4.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "ipv4_test",
+ size = "small",
+ srcs = ["ipv4_test.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
new file mode 100644
index 000000000..4cbefe5ab
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -0,0 +1,160 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+ h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return
+ }
+ hdr := header.IPv4(h)
+
+ // We don't use IsValid() here because ICMP only requires that the IP
+ // header plus 8 bytes of the transport header be included. So it's
+ // likely that it is truncated, which would cause IsValid to return
+ // false.
+ //
+ // Drop packet if it doesn't have the basic IPv4 header or if the
+ // original source address doesn't match the endpoint's address.
+ if hdr.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ hlen := int(hdr.HeaderLength())
+ if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 {
+ // We won't be able to handle this if it doesn't contain the
+ // full IPv4 header, or if it's a fragment not at offset 0
+ // (because it won't have the transport header).
+ return
+ }
+
+ // Skip the ip header, then deliver control message.
+ pkt.Data.TrimFront(hlen)
+ p := hdr.TransportProtocol()
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
+ stats := r.Stats()
+ received := stats.ICMP.V4PacketsReceived
+ v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.ICMPv4(v)
+
+ // TODO(b/112892170): Meaningfully handle all ICMP types.
+ switch h.Type() {
+ case header.ICMPv4Echo:
+ received.Echo.Increment()
+
+ // Only send a reply if the checksum is valid.
+ wantChecksum := h.Checksum()
+ // Reset the checksum field to 0 to can calculate the proper
+ // checksum. We'll have to reset this before we hand the packet
+ // off.
+ h.SetChecksum(0)
+ gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ if gotChecksum != wantChecksum {
+ // It's possible that a raw socket expects to receive this.
+ h.SetChecksum(wantChecksum)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ received.Invalid.Increment()
+ return
+ }
+
+ // It's possible that a raw socket expects to receive this.
+ h.SetChecksum(wantChecksum)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{
+ Data: pkt.Data.Clone(nil),
+ NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...),
+ })
+
+ vv := pkt.Data.Clone(nil)
+ vv.TrimFront(header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ copy(pkt, h)
+ pkt.SetType(header.ICMPv4EchoReply)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
+ sent := stats.ICMP.V4PacketsSent
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ Header: hdr,
+ Data: vv,
+ TransportHeader: buffer.View(pkt),
+ }); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.EchoReply.Increment()
+
+ case header.ICMPv4EchoReply:
+ received.EchoReply.Increment()
+
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+
+ case header.ICMPv4DstUnreachable:
+ received.DstUnreachable.Increment()
+
+ pkt.Data.TrimFront(header.ICMPv4MinimumSize)
+ switch h.Code() {
+ case header.ICMPv4PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, pkt)
+
+ case header.ICMPv4FragmentationNeeded:
+ mtu := uint32(h.MTU())
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+ }
+
+ case header.ICMPv4SrcQuench:
+ received.SrcQuench.Increment()
+
+ case header.ICMPv4Redirect:
+ received.Redirect.Increment()
+
+ case header.ICMPv4TimeExceeded:
+ received.TimeExceeded.Increment()
+
+ case header.ICMPv4ParamProblem:
+ received.ParamProblem.Increment()
+
+ case header.ICMPv4Timestamp:
+ received.Timestamp.Increment()
+
+ case header.ICMPv4TimestampReply:
+ received.TimestampReply.Increment()
+
+ case header.ICMPv4InfoRequest:
+ received.InfoRequest.Increment()
+
+ case header.ICMPv4InfoReply:
+ received.InfoReply.Increment()
+
+ default:
+ received.Invalid.Increment()
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
new file mode 100644
index 000000000..9db42b2a4
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -0,0 +1,598 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ipv4 contains the implementation of the ipv4 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv4.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv4.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv4
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolNumber is the ipv4 protocol number.
+ ProtocolNumber = header.IPv4ProtocolNumber
+
+ // MaxTotalSize is maximum size that can be encoded in the 16-bit
+ // TotalLength field of the ipv4 header.
+ MaxTotalSize = 0xffff
+
+ // DefaultTTL is the default time-to-live value for this endpoint.
+ DefaultTTL = 64
+
+ // buckets is the number of identifier buckets.
+ buckets = 2048
+)
+
+type endpoint struct {
+ nicID tcpip.NICID
+ id stack.NetworkEndpointID
+ prefixLen int
+ linkEP stack.LinkEndpoint
+ dispatcher stack.TransportDispatcher
+ fragmentation *fragmentation.Fragmentation
+ protocol *protocol
+ stack *stack.Stack
+}
+
+// NewEndpoint creates a new ipv4 endpoint.
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+ e := &endpoint{
+ nicID: nicID,
+ id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ linkEP: linkEP,
+ dispatcher: dispatcher,
+ fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
+ stack: st,
+ }
+
+ return e, nil
+}
+
+// DefaultTTL is the default time-to-live value for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return e.protocol.DefaultTTL()
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// NICID returns the ID of the NIC this endpoint belongs to.
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+// ID returns the ipv4 endpoint ID.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &e.id
+}
+
+// PrefixLen returns the ipv4 endpoint subnet prefix length in bits.
+func (e *endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return e.protocol.Number()
+}
+
+// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
+// write. It assumes that the IP header is entirely in pkt.Header but does not
+// assume that only the IP header is in pkt.Header. It assumes that the input
+// packet's stated length matches the length of the header+payload. mtu
+// includes the IP header and options. This does not support the DontFragment
+// IP flag.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error {
+ // This packet is too big, it needs to be fragmented.
+ ip := header.IPv4(pkt.Header.View())
+ flags := ip.Flags()
+
+ // Update mtu to take into account the header, which will exist in all
+ // fragments anyway.
+ innerMTU := mtu - int(ip.HeaderLength())
+
+ // Round the MTU down to align to 8 bytes. Then calculate the number of
+ // fragments. Calculate fragment sizes as in RFC791.
+ innerMTU &^= 7
+ n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
+
+ outerMTU := innerMTU + int(ip.HeaderLength())
+ offset := ip.FragmentOffset()
+ originalAvailableLength := pkt.Header.AvailableLength()
+ for i := 0; i < n; i++ {
+ // Where possible, the first fragment that is sent has the same
+ // pkt.Header.UsedLength() as the input packet. The link-layer
+ // endpoint may depend on this for looking at, eg, L4 headers.
+ h := ip
+ if i > 0 {
+ pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
+ h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength())))
+ copy(h, ip[:ip.HeaderLength()])
+ }
+ if i != n-1 {
+ h.SetTotalLength(uint16(outerMTU))
+ h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
+ } else {
+ h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size()))
+ h.SetFlagsFragmentOffset(flags, offset)
+ }
+ h.SetChecksum(0)
+ h.SetChecksum(^h.CalculateChecksum())
+ offset += uint16(innerMTU)
+ if i > 0 {
+ newPayload := pkt.Data.Clone(nil)
+ newPayload.CapLength(innerMTU)
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
+ Header: pkt.Header,
+ Data: newPayload,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ pkt.Data.TrimFront(newPayload.Size())
+ continue
+ }
+ // Special handling for the first fragment because it comes
+ // from the header.
+ if outerMTU >= pkt.Header.UsedLength() {
+ // This fragment can fit all of pkt.Header and possibly
+ // some of pkt.Data, too.
+ newPayload := pkt.Data.Clone(nil)
+ newPayloadLength := outerMTU - pkt.Header.UsedLength()
+ newPayload.CapLength(newPayloadLength)
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
+ Header: pkt.Header,
+ Data: newPayload,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ pkt.Data.TrimFront(newPayloadLength)
+ } else {
+ // The fragment is too small to fit all of pkt.Header.
+ startOfHdr := pkt.Header
+ startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
+ emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
+ Header: startOfHdr,
+ Data: emptyVV,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ // Add the unused bytes of pkt.Header into the pkt.Data
+ // that remains to be sent.
+ restOfHdr := pkt.Header.View()[outerMTU:]
+ tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
+ tmp.Append(pkt.Data)
+ pkt.Data = tmp
+ }
+ }
+ return nil
+}
+
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 {
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ length := uint16(hdr.UsedLength() + payloadSize)
+ id := uint32(0)
+ if length > header.IPv4MaximumHeaderSize+8 {
+ // Packets of 68 bytes or less are required by RFC 791 to not be
+ // fragmented, so we only assign ids to larger packets.
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: length,
+ ID: uint16(id),
+ TTL: params.TTL,
+ TOS: params.TOS,
+ Protocol: uint8(params.Protocol),
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return ip
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
+
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Output, &pkt, gso, r, ""); !ok {
+ // iptables is telling us to drop the packet.
+ return nil
+ }
+
+ if pkt.NatDone {
+ // If the packet is manipulated as per NAT Ouput rules, handle packet
+ // based on destination address and do not send the packet to link layer.
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ if err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
+ packet := stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)}
+ ep.HandlePacket(&route, packet)
+ return nil
+ }
+ }
+
+ if r.Loop&stack.PacketLoop != 0 {
+ // The inbound path expects the network header to still be in
+ // the PacketBuffer's Data field.
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
+ loopedR := r.MakeLoopedRoute()
+
+ e.HandlePacket(&loopedR, stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ })
+
+ loopedR.Release()
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return nil
+ }
+ if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ return nil
+}
+
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
+ panic("multiple packets in local loop")
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return pkts.Len(), nil
+ }
+
+ for pkt := pkts.Front(); pkt != nil; {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
+ pkt = pkt.Next()
+ }
+
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ ipt := e.stack.IPTables()
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r)
+ if len(dropped) == 0 && len(natPkts) == 0 {
+ // Fast path: If no packets are to be dropped then we can just invoke the
+ // faster WritePackets API directly.
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+
+ // Slow Path as we are dropping some packets in the batch degrade to
+ // emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
+ continue
+ }
+ if _, ok := natPkts[pkt]; ok {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ if err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
+ packet := stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)}
+ ep.HandlePacket(&route, packet)
+ n++
+ continue
+ }
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+ n++
+ }
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, nil
+}
+
+// WriteHeaderIncludedPacket writes a packet already containing a network
+// header through the given route.
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+ // The packet already has an IP header, but there are a few required
+ // checks.
+ h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
+ ip := header.IPv4(h)
+ if !ip.IsValid(pkt.Data.Size()) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ // Always set the total length.
+ ip.SetTotalLength(uint16(pkt.Data.Size()))
+
+ // Set the source address when zero.
+ if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
+ ip.SetSourceAddress(r.LocalAddress)
+ }
+
+ // Set the destination. If the packet already included a destination,
+ // it will be part of the route.
+ ip.SetDestinationAddress(r.RemoteAddress)
+
+ // Set the packet ID when zero.
+ if ip.ID() == 0 {
+ id := uint32(0)
+ if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 {
+ // Packets of 68 bytes or less are required by RFC 791 to not be
+ // fragmented, so we only assign ids to larger packets.
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
+ }
+ ip.SetID(uint16(id))
+ }
+
+ // Always set the checksum.
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ if r.Loop&stack.PacketLoop != 0 {
+ e.HandlePacket(r, pkt.Clone())
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return nil
+ }
+
+ r.Stats().IP.PacketsSent.Increment()
+
+ ip = ip[:ip.HeaderLength()]
+ pkt.Header = buffer.NewPrependableFromView(buffer.View(ip))
+ pkt.Data.TrimFront(int(ip.HeaderLength()))
+ return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+}
+
+// HandlePacket is called by the link layer when new ipv4 packets arrive for
+// this endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
+ headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ h := header.IPv4(headerView)
+ if !h.IsValid(pkt.Data.Size()) {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ pkt.NetworkHeader = headerView[:h.HeaderLength()]
+
+ hlen := int(h.HeaderLength())
+ tlen := int(h.TotalLength())
+ pkt.Data.TrimFront(hlen)
+ pkt.Data.CapLength(tlen - hlen)
+
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and will not be forwarded.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Input, &pkt, nil, nil, ""); !ok {
+ // iptables is telling us to drop the packet.
+ return
+ }
+
+ more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
+ if more || h.FragmentOffset() != 0 {
+ if pkt.Data.Size() == 0 {
+ // Drop the packet as it's marked as a fragment but has
+ // no payload.
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ // The packet is a fragment, let's try to reassemble it.
+ last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
+ // Drop the packet if the fragmentOffset is incorrect. i.e the
+ // combination of fragmentOffset and pkt.Data.size() causes a
+ // wrap around resulting in last being less than the offset.
+ if last < h.FragmentOffset() {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ var ready bool
+ var err error
+ pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data)
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ if !ready {
+ return
+ }
+ }
+ p := h.TransportProtocol()
+ if p == header.ICMPv4ProtocolNumber {
+ headerView.CapLength(hlen)
+ e.handleICMP(r, pkt)
+ return
+ }
+ r.Stats().IP.PacketsDelivered.Increment()
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
+}
+
+// Close cleans up resources associated with the endpoint.
+func (e *endpoint) Close() {}
+
+type protocol struct {
+ ids []uint32
+ hashIV uint32
+
+ // defaultTTL is the current default TTL for the protocol. Only the
+ // uint8 portion of it is meaningful and it must be accessed
+ // atomically.
+ defaultTTL uint32
+}
+
+// Number returns the ipv4 protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
+// MinimumPacketSize returns the minimum valid ipv4 packet size.
+func (p *protocol) MinimumPacketSize() int {
+ return header.IPv4MinimumSize
+}
+
+// DefaultPrefixLen returns the IPv4 default prefix length.
+func (p *protocol) DefaultPrefixLen() int {
+ return header.IPv4AddressSize * 8
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv4(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(v))
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(p.DefaultTTL())
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// SetDefaultTTL sets the default TTL for endpoints created with this protocol.
+func (p *protocol) SetDefaultTTL(ttl uint8) {
+ atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
+}
+
+// DefaultTTL returns the default TTL for endpoints created with this protocol.
+func (p *protocol) DefaultTTL() uint8 {
+ return uint8(atomic.LoadUint32(&p.defaultTTL))
+}
+
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload mtu.
+func calculateMTU(mtu uint32) uint32 {
+ if mtu > MaxTotalSize {
+ mtu = MaxTotalSize
+ }
+ return mtu - header.IPv4MinimumSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source &
+// destination address, the transport protocol number, and a random initial
+// value (generated once on initialization) to generate the hash.
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
+ t := r.LocalAddress
+ a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = r.RemoteAddress
+ b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return hash.Hash3Words(a, b, uint32(protocol), hashIV)
+}
+
+// NewProtocol returns an IPv4 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ ids := make([]uint32, buckets)
+
+ // Randomly initialize hashIV and the ids.
+ r := hash.RandN32(1 + buckets)
+ for i := range ids {
+ ids[i] = r[i]
+ }
+ hashIV := r[buckets]
+
+ return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
new file mode 100644
index 000000000..5a864d832
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -0,0 +1,475 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4_test
+
+import (
+ "bytes"
+ "encoding/hex"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestExcludeBroadcast(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+
+ const defaultMTU = 65536
+ ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
+ if testing.Verbose() {
+ ep = sniffer.New(ep)
+ }
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ }})
+
+ randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53}
+
+ var wq waiter.Queue
+ t.Run("WithoutPrimaryAddress", func(t *testing.T) {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
+
+ // Cannot connect using a broadcast address as the source.
+ if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ }
+
+ // However, we can bind to a broadcast address to listen.
+ if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil {
+ t.Errorf("Bind failed: %v", err)
+ }
+ })
+
+ t.Run("WithPrimaryAddress", func(t *testing.T) {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
+
+ // Add a valid primary endpoint address, now we can connect.
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+ if err := ep.Connect(randomAddr); err != nil {
+ t.Errorf("Connect failed: %v", err)
+ }
+ })
+}
+
+// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much
+// data should already be in the header before WritePacket. extraLength
+// indicates how much extra space should be in the header. The payload is made
+// from many Views of the sizes listed in viewSizes.
+func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) {
+ hdr := buffer.NewPrependable(hdrLength + extraLength)
+ hdr.Prepend(hdrLength)
+ rand.Read(hdr.View())
+
+ var views []buffer.View
+ totalLength := 0
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ rand.Read(newView)
+ views = append(views, newView)
+ totalLength += s
+ }
+ payload := buffer.NewVectorisedView(totalLength, views)
+ return hdr, payload
+}
+
+// comparePayloads compared the contents of all the packets against the contents
+// of the source packet.
+func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) {
+ t.Helper()
+ // Make a complete array of the sourcePacketInfo packet.
+ source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
+ source = append(source, sourcePacketInfo.Header.View()...)
+ source = append(source, sourcePacketInfo.Data.ToView()...)
+
+ // Make a copy of the IP header, which will be modified in some fields to make
+ // an expected header.
+ sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...))
+ sourceCopy.SetChecksum(0)
+ sourceCopy.SetFlagsFragmentOffset(0, 0)
+ sourceCopy.SetTotalLength(0)
+ var offset uint16
+ // Build up an array of the bytes sent.
+ var reassembledPayload []byte
+ for i, packet := range packets {
+ // Confirm that the packet is valid.
+ allBytes := packet.Header.View().ToVectorisedView()
+ allBytes.Append(packet.Data)
+ ip := header.IPv4(allBytes.ToView())
+ if !ip.IsValid(len(ip)) {
+ t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
+ }
+ if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want {
+ t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want)
+ }
+ if got, want := len(ip), int(mtu); got > want {
+ t.Errorf("fragment is too large, got %d want %d", got, want)
+ }
+ if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want {
+ t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
+ }
+ if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want {
+ t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
+ }
+ if i < len(packets)-1 {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
+ } else {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
+ }
+ reassembledPayload = append(reassembledPayload, ip.Payload()...)
+ offset += ip.TotalLength() - uint16(ip.HeaderLength())
+ // Clear out the checksum and length from the ip because we can't compare
+ // it.
+ sourceCopy.SetTotalLength(uint16(len(ip)))
+ sourceCopy.SetChecksum(0)
+ sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
+ if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) {
+ t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()]))
+ }
+ }
+ expected := source[source.HeaderLength():]
+ if !bytes.Equal(reassembledPayload, expected) {
+ t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected))
+ }
+}
+
+type errorChannel struct {
+ *channel.Endpoint
+ Ch chan stack.PacketBuffer
+ packetCollectorErrors []*tcpip.Error
+}
+
+// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
+// will return successive errors from packetCollectorErrors until the list is
+// empty and then return nil each time.
+func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
+ return &errorChannel{
+ Endpoint: channel.New(size, mtu, linkAddr),
+ Ch: make(chan stack.PacketBuffer, size),
+ packetCollectorErrors: packetCollectorErrors,
+ }
+}
+
+// Drain removes all outbound packets from the channel and counts them.
+func (e *errorChannel) Drain() int {
+ c := 0
+ for {
+ select {
+ case <-e.Ch:
+ c++
+ default:
+ return c
+ }
+ }
+}
+
+// WritePacket stores outbound packets into the channel.
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+ select {
+ case e.Ch <- pkt:
+ default:
+ }
+
+ nextError := (*tcpip.Error)(nil)
+ if len(e.packetCollectorErrors) > 0 {
+ nextError = e.packetCollectorErrors[0]
+ e.packetCollectorErrors = e.packetCollectorErrors[1:]
+ }
+ return nextError
+}
+
+type context struct {
+ stack.Route
+ linkEP *errorChannel
+}
+
+func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
+ // Make the packet and write it.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
+ s.CreateNIC(1, ep)
+ const (
+ src = "\x10\x00\x00\x01"
+ dst = "\x10\x00\x00\x02"
+ )
+ s.AddAddress(1, ipv4.ProtocolNumber, src)
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ }
+ return context{
+ Route: r,
+ linkEP: ep,
+ }
+}
+
+func TestFragmentation(t *testing.T) {
+ var manyPayloadViewsSizes [1000]int
+ for i := range manyPayloadViewsSizes {
+ manyPayloadViewsSizes[i] = 7
+ }
+ fragTests := []struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ hdrLength int
+ extraLength int
+ payloadViewsSizes []int
+ expectedFrags int
+ }{
+ {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
+ {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
+ {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
+ {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
+ {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
+ {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
+ {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
+ }
+
+ for _, ft := range fragTests {
+ t.Run(ft.description, func(t *testing.T) {
+ hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+ source := stack.PacketBuffer{
+ Header: hdr,
+ // Save the source payload because WritePacket will modify it.
+ Data: payload.Clone(nil),
+ }
+ c := buildContext(t, nil, ft.mtu)
+ err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
+ if err != nil {
+ t.Errorf("err got %v, want %v", err, nil)
+ }
+
+ var results []stack.PacketBuffer
+ L:
+ for {
+ select {
+ case pi := <-c.linkEP.Ch:
+ results = append(results, pi)
+ default:
+ break L
+ }
+ }
+
+ if got, want := len(results), ft.expectedFrags; got != want {
+ t.Errorf("len(result) got %d, want %d", got, want)
+ }
+ if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ }
+ compareFragments(t, results, source, ft.mtu)
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from write packet
+// correctly.
+func TestFragmentationErrors(t *testing.T) {
+ fragTests := []struct {
+ description string
+ mtu uint32
+ hdrLength int
+ payloadViewsSizes []int
+ packetCollectorErrors []*tcpip.Error
+ }{
+ {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
+ {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ }
+
+ for _, ft := range fragTests {
+ t.Run(ft.description, func(t *testing.T) {
+ hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
+ c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
+ err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
+ for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
+ if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
+ t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
+ }
+ }
+ // We only need to check that last error because all the ones before are
+ // nil.
+ if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
+ t.Errorf("err got %v, want %v", got, want)
+ }
+ if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
+ t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+ }
+ })
+ }
+}
+
+func TestInvalidFragments(t *testing.T) {
+ // These packets have both IHL and TotalLength set to 0.
+ testCases := []struct {
+ name string
+ packets [][]byte
+ wantMalformedIPPackets uint64
+ wantMalformedFragments uint64
+ }{
+ {
+ "ihl_totallen_zero_valid_frag_offset",
+ [][]byte{
+ {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 0,
+ },
+ {
+ "ihl_totallen_zero_invalid_frag_offset",
+ [][]byte{
+ {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 0,
+ },
+ {
+ // Total Length of 37(20 bytes IP header + 17 bytes of
+ // payload)
+ // Frag Offset of 0x1ffe = 8190*8 = 65520
+ // Leading to the fragment end to be past 65535.
+ "ihl_totallen_valid_invalid_frag_offset_1",
+ [][]byte{
+ {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 1,
+ },
+ // The following 3 tests were found by running a fuzzer and were
+ // triggering a panic in the IPv4 reassembler code.
+ {
+ "ihl_less_than_ipv4_minimum_size_1",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "ihl_less_than_ipv4_minimum_size_2",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "ihl_less_than_ipv4_minimum_size_3",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "fragment_with_short_total_len_extra_payload",
+ [][]byte{
+ {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 1,
+ },
+ {
+ "multiple_fragments_with_more_fragments_set_to_false",
+ [][]byte{
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ },
+ 1,
+ 1,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ const nicID tcpip.NICID = 42
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ },
+ })
+
+ var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
+ var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
+ ep := channel.New(10, 1500, linkAddr)
+ s.CreateNIC(nicID, sniffer.New(ep))
+
+ for _, pkt := range tc.packets {
+ ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
+ })
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
+ t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
+ }
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want {
+ t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
new file mode 100644
index 000000000..3f71fc520
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -0,0 +1,44 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ipv6",
+ srcs = [
+ "icmp.go",
+ "ipv6.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "ipv6_test",
+ size = "small",
+ srcs = [
+ "icmp_test.go",
+ "ipv6_test.go",
+ "ndp_test.go",
+ ],
+ library = ":ipv6",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
new file mode 100644
index 000000000..bdf3a0d25
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -0,0 +1,546 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+ h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return
+ }
+ hdr := header.IPv6(h)
+
+ // We don't use IsValid() here because ICMP only requires that up to
+ // 1280 bytes of the original packet be included. So it's likely that it
+ // is truncated, which would cause IsValid to return false.
+ //
+ // Drop packet if it doesn't have the basic IPv6 header or if the
+ // original source address doesn't match the endpoint's address.
+ if hdr.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ // Skip the IP header, then handle the fragmentation header if there
+ // is one.
+ pkt.Data.TrimFront(header.IPv6MinimumSize)
+ p := hdr.TransportProtocol()
+ if p == header.IPv6FragmentHeader {
+ f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize)
+ if !ok {
+ return
+ }
+ fragHdr := header.IPv6Fragment(f)
+ if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 {
+ // We can't handle fragments that aren't at offset 0
+ // because they don't have the transport headers.
+ return
+ }
+
+ // Skip fragmentation header and find out the actual protocol
+ // number.
+ pkt.Data.TrimFront(header.IPv6FragmentHeaderSize)
+ p = fragHdr.TransportProtocol()
+ }
+
+ // Deliver the control packet to the transport endpoint.
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) {
+ stats := r.Stats().ICMP
+ sent := stats.V6PacketsSent
+ received := stats.V6PacketsReceived
+ v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.ICMPv6(v)
+ iph := header.IPv6(netHeader)
+
+ // Validate ICMPv6 checksum before processing the packet.
+ //
+ // This copy is used as extra payload during the checksum calculation.
+ payload := pkt.Data.Clone(nil)
+ payload.TrimFront(len(h))
+ if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
+ received.Invalid.Increment()
+ return
+ }
+
+ isNDPValid := func() bool {
+ // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
+ // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
+ // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not
+ // set to 0.
+ //
+ // As per RFC 6980 section 5, nodes MUST silently drop NDP messages if the
+ // packet includes a fragmentation header.
+ return !hasFragmentHeader && iph.HopLimit() == header.NDPHopLimit && h.Code() == 0
+ }
+
+ // TODO(b/112892170): Meaningfully handle all ICMP types.
+ switch h.Type() {
+ case header.ICMPv6PacketTooBig:
+ received.PacketTooBig.Increment()
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
+ mtu := header.ICMPv6(hdr).MTU()
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+
+ case header.ICMPv6DstUnreachable:
+ received.DstUnreachable.Increment()
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch header.ICMPv6(hdr).Code() {
+ case header.ICMPv6PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, pkt)
+ }
+
+ case header.ICMPv6NeighborSolicit:
+ received.NeighborSolicit.Increment()
+ if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ // The remainder of payload must be only the neighbor solicitation, so
+ // payload.ToView() always returns the solicitation. Per RFC 6980 section 5,
+ // NDP messages cannot be fragmented. Also note that in the common case NDP
+ // datagrams are very small and ToView() will not incur allocations.
+ ns := header.NDPNeighborSolicit(payload.ToView())
+ it, err := ns.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NS option, drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ targetAddr := ns.TargetAddress()
+ s := r.Stack()
+ if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ // We will only get an error if the NIC is unrecognized, which should not
+ // happen. For now, drop this packet.
+ //
+ // TODO(b/141002840): Handle this better?
+ return
+ } else if isTentative {
+ // If the target address is tentative and the source of the packet is a
+ // unicast (specified) address, then the source of the packet is
+ // attempting to perform address resolution on the target. In this case,
+ // the solicitation is silently ignored, as per RFC 4862 section 5.4.3.
+ //
+ // If the target address is tentative and the source of the packet is the
+ // unspecified address (::), then we know another node is also performing
+ // DAD for the same address (since the target address is tentative for us,
+ // we know we are also performing DAD on it). In this case we let the
+ // stack know so it can handle such a scenario and do nothing further with
+ // the NS.
+ if r.RemoteAddress == header.IPv6Any {
+ s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ }
+
+ // Do not handle neighbor solicitations targeted to an address that is
+ // tentative on the NIC any further.
+ return
+ }
+
+ // At this point we know that the target address is not tentative on the NIC
+ // so the packet is processed as defined in RFC 4861, as per RFC 4862
+ // section 5.4.3.
+
+ // Is the NS targetting us?
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
+ return
+ }
+
+ // If the NS message contains the Source Link-Layer Address option, update
+ // the link address cache with the value of the option.
+ //
+ // TODO(b/148429853): Properly process the NS message and do Neighbor
+ // Unreachability Detection.
+ var sourceLinkAddr tcpip.LinkAddress
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
+ }
+ if done {
+ break
+ }
+
+ switch opt := opt.(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ // No RFCs define what to do when an NS message has multiple Source
+ // Link-Layer Address options. Since no interface can have multiple
+ // link-layer addresses, we consider such messages invalid.
+ if len(sourceLinkAddr) != 0 {
+ received.Invalid.Increment()
+ return
+ }
+
+ sourceLinkAddr = opt.EthernetAddress()
+ }
+ }
+
+ unspecifiedSource := r.RemoteAddress == header.IPv6Any
+
+ // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST
+ // NOT be included when the source IP address is the unspecified address.
+ // Otherwise, on link layers that have addresses this option MUST be
+ // included in multicast solicitations and SHOULD be included in unicast
+ // solicitations.
+ if len(sourceLinkAddr) == 0 {
+ if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource {
+ received.Invalid.Increment()
+ return
+ }
+ } else if unspecifiedSource {
+ received.Invalid.Increment()
+ return
+ } else {
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
+ }
+
+ // ICMPv6 Neighbor Solicit messages are always sent to
+ // specially crafted IPv6 multicast addresses. As a result, the
+ // route we end up with here has as its LocalAddress such a
+ // multicast address. It would be nonsense to claim that our
+ // source address is a multicast address, so we manually set
+ // the source address to the target address requested in the
+ // solicit message. Since that requires mutating the route, we
+ // must first clone it.
+ r := r.Clone()
+ defer r.Release()
+ r.LocalAddress = targetAddr
+
+ // As per RFC 4861 section 7.2.4, if the the source of the solicitation is
+ // the unspecified address, the node MUST set the Solicited flag to zero and
+ // multicast the advertisement to the all-nodes address.
+ solicited := true
+ if unspecifiedSource {
+ solicited = false
+ r.RemoteAddress = header.IPv6AllNodesMulticastAddress
+ }
+
+ // If the NS has a source link-layer option, use the link address it
+ // specifies as the remote link address for the response instead of the
+ // source link address of the packet.
+ //
+ // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link
+ // address cache for the right destination link address instead of manually
+ // patching the route with the remote link address if one is specified in a
+ // Source Link-Layer Address option.
+ if len(sourceLinkAddr) != 0 {
+ r.RemoteLinkAddress = sourceLinkAddr
+ }
+
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
+ packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ packet.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na.SetSolicitedFlag(solicited)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(targetAddr)
+ opts := na.Options()
+ opts.Serialize(optsSerializer)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ // RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
+ //
+ // 7.1.2. Validation of Neighbor Advertisements
+ //
+ // The IP Hop Limit field has a value of 255, i.e., the packet
+ // could not possibly have been forwarded by a router.
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ Header: hdr,
+ }); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.NeighborAdvert.Increment()
+
+ case header.ICMPv6NeighborAdvert:
+ received.NeighborAdvert.Increment()
+ if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ // The remainder of payload must be only the neighbor advertisement, so
+ // payload.ToView() always returns the advertisement. Per RFC 6980 section
+ // 5, NDP messages cannot be fragmented. Also note that in the common case
+ // NDP datagrams are very small and ToView() will not incur allocations.
+ na := header.NDPNeighborAdvert(payload.ToView())
+ it, err := na.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ targetAddr := na.TargetAddress()
+ stack := r.Stack()
+
+ if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ // We will only get an error if the NIC is unrecognized, which should not
+ // happen. For now short-circuit this packet.
+ //
+ // TODO(b/141002840): Handle this better?
+ return
+ } else if isTentative {
+ // We just got an NA from a node that owns an address we are performing
+ // DAD on, implying the address is not unique. In this case we let the
+ // stack know so it can handle such a scenario and do nothing furthur with
+ // the NDP NA.
+ stack.DupTentativeAddrDetected(e.nicID, targetAddr)
+ return
+ }
+
+ // At this point we know that the target address is not tentative on the
+ // NIC. However, the target address may still be assigned to the NIC but not
+ // tentative (it could be permanent). Such a scenario is beyond the scope of
+ // RFC 4862. As such, we simply ignore such a scenario for now and proceed
+ // as normal.
+ //
+ // TODO(b/143147598): Handle the scenario described above. Also inform the
+ // netstack integration that a duplicate address was detected outside of
+ // DAD.
+
+ // If the NA message has the target link layer option, update the link
+ // address cache with the link address for the target of the message.
+ //
+ // TODO(b/148429853): Properly process the NA message and do Neighbor
+ // Unreachability Detection.
+ var targetLinkAddr tcpip.LinkAddress
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
+ }
+ if done {
+ break
+ }
+
+ switch opt := opt.(type) {
+ case header.NDPTargetLinkLayerAddressOption:
+ // No RFCs define what to do when an NA message has multiple Target
+ // Link-Layer Address options. Since no interface can have multiple
+ // link-layer addresses, we consider such messages invalid.
+ if len(targetLinkAddr) != 0 {
+ received.Invalid.Increment()
+ return
+ }
+
+ targetLinkAddr = opt.EthernetAddress()
+ }
+ }
+
+ if len(targetLinkAddr) != 0 {
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ }
+
+ case header.ICMPv6EchoRequest:
+ received.EchoRequest.Increment()
+ icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
+ packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(packet, icmpHdr)
+ packet.SetType(header.ICMPv6EchoReply)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ Header: hdr,
+ Data: pkt.Data,
+ }); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.EchoReply.Increment()
+
+ case header.ICMPv6EchoReply:
+ received.EchoReply.Increment()
+ if pkt.Data.Size() < header.ICMPv6EchoMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt)
+
+ case header.ICMPv6TimeExceeded:
+ received.TimeExceeded.Increment()
+
+ case header.ICMPv6ParamProblem:
+ received.ParamProblem.Increment()
+
+ case header.ICMPv6RouterSolicit:
+ received.RouterSolicit.Increment()
+ if !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ case header.ICMPv6RouterAdvert:
+ received.RouterAdvert.Increment()
+
+ // Is the NDP payload of sufficient size to hold a Router
+ // Advertisement?
+ if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ routerAddr := iph.SourceAddress()
+
+ //
+ // Validate the RA as per RFC 4861 section 6.1.2.
+ //
+
+ // Is the IP Source Address a link-local address?
+ if !header.IsV6LinkLocalAddress(routerAddr) {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ // The remainder of payload must be only the router advertisement, so
+ // payload.ToView() always returns the advertisement. Per RFC 6980 section
+ // 5, NDP messages cannot be fragmented. Also note that in the common case
+ // NDP datagrams are very small and ToView() will not incur allocations.
+ ra := header.NDPRouterAdvert(payload.ToView())
+ opts := ra.Options()
+
+ // Are options valid as per the wire format?
+ if _, err := opts.Iter(true); err != nil {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ //
+ // At this point, we have a valid Router Advertisement, as far
+ // as RFC 4861 section 6.1.2 is concerned.
+ //
+
+ // Tell the NIC to handle the RA.
+ stack := r.Stack()
+ rxNICID := r.NICID()
+ stack.HandleNDPRA(rxNICID, routerAddr, ra)
+
+ case header.ICMPv6RedirectMsg:
+ received.RedirectMsg.Increment()
+ if !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ default:
+ received.Invalid.Increment()
+ }
+}
+
+const (
+ ndpSolicitedFlag = 1 << 6
+ ndpOverrideFlag = 1 << 5
+
+ ndpOptSrcLinkAddr = 1
+ ndpOptDstLinkAddr = 2
+
+ icmpV6FlagOffset = 4
+ icmpV6OptOffset = 24
+ icmpV6LengthOffset = 25
+)
+
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+var _ stack.LinkAddressResolver = (*protocol)(nil)
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ snaddr := header.SolicitedNodeAddr(addr)
+
+ // TODO(b/148672031): Use stack.FindRoute instead of manually creating the
+ // route here. Note, we would need the nicID to do this properly so the right
+ // NIC (associated to linkEP) is used to send the NDP NS message.
+ r := &stack.Route{
+ LocalAddress: localAddr,
+ RemoteAddress: snaddr,
+ RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr),
+ }
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ copy(pkt[icmpV6OptOffset-len(addr):], addr)
+ pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
+ pkt[icmpV6LengthOffset] = 1
+ copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ length := uint16(hdr.UsedLength())
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+
+ // TODO(stijlist): count this in ICMP stats.
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
+ Header: hdr,
+ })
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if header.IsV6MulticastAddress(addr) {
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
+ }
+ return tcpip.LinkAddress([]byte(nil)), false
+}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
new file mode 100644
index 000000000..d412ff688
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -0,0 +1,960 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "context"
+ "reflect"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+)
+
+var (
+ lladdr0 = header.LinkLocalAddr(linkAddr0)
+ lladdr1 = header.LinkLocalAddr(linkAddr1)
+)
+
+type stubLinkEndpoint struct {
+ stack.LinkEndpoint
+}
+
+func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
+func (*stubLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {}
+
+type stubDispatcher struct {
+ stack.TransportDispatcher
+}
+
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) {
+}
+
+type stubLinkAddressCache struct {
+ stack.LinkAddressCache
+}
+
+func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID {
+ return 0
+}
+
+func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
+}
+
+func TestICMPCounts(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
+ {
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+ ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
+ if err != nil {
+ t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ }
+
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
+ defer r.Release()
+
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ }{
+ {
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ },
+ {
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ },
+ {
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ },
+ }
+
+ handleIPv6Payload := func(hdr buffer.Prependable) {
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ep.HandlePacket(&r, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ for _, typ := range types {
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
+
+ handleIPv6Payload(hdr)
+ }
+
+ // Construct an empty ICMP packet so that
+ // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
+ handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize))
+
+ icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
+ if got, want := s.Value(), uint64(1); got != want {
+ t.Errorf("got %s = %d, want = %d", name, got, want)
+ }
+ })
+ if t.Failed() {
+ t.Logf("stats:\n%+v", s.Stats())
+ }
+}
+
+func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) {
+ t := v.Type()
+ for i := 0; i < v.NumField(); i++ {
+ v := v.Field(i)
+ if s, ok := v.Interface().(*tcpip.StatCounter); ok {
+ f(t.Field(i).Name, s)
+ } else {
+ visitStats(v, f)
+ }
+ }
+}
+
+type testContext struct {
+ s0 *stack.Stack
+ s1 *stack.Stack
+
+ linkEP0 *channel.Endpoint
+ linkEP1 *channel.Endpoint
+}
+
+type endpointWithResolutionCapability struct {
+ stack.LinkEndpoint
+}
+
+func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities {
+ return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired
+}
+
+func newTestContext(t *testing.T) *testContext {
+ c := &testContext{
+ s0: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
+ s1: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
+ }
+
+ const defaultMTU = 65536
+ c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+
+ wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
+ if testing.Verbose() {
+ wrappedEP0 = sniffer.New(wrappedEP0)
+ }
+ if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
+ t.Fatalf("CreateNIC s0: %v", err)
+ }
+ if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress lladdr0: %v", err)
+ }
+
+ c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
+ if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
+ t.Fatalf("AddAddress lladdr1: %v", err)
+ }
+
+ subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.s0.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet0,
+ NIC: 1,
+ }},
+ )
+ subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.s1.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet1,
+ NIC: 1,
+ }},
+ )
+
+ return c
+}
+
+func (c *testContext) cleanup() {
+ c.linkEP0.Close()
+ c.linkEP1.Close()
+}
+
+type routeArgs struct {
+ src, dst *channel.Endpoint
+ typ header.ICMPv6Type
+ remoteLinkAddr tcpip.LinkAddress
+}
+
+func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
+ t.Helper()
+
+ pi, _ := args.src.ReadContext(context.Background())
+
+ {
+ views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
+ size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
+ vv := buffer.NewVectorisedView(size, views)
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{
+ Data: vv,
+ })
+ }
+
+ if pi.Proto != ProtocolNumber {
+ t.Errorf("unexpected protocol number %d", pi.Proto)
+ return
+ }
+
+ if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
+ t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
+ }
+
+ ipv6 := header.IPv6(pi.Pkt.Header.View())
+ transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
+ if transProto != header.ICMPv6ProtocolNumber {
+ t.Errorf("unexpected transport protocol number %d", transProto)
+ return
+ }
+ icmpv6 := header.ICMPv6(ipv6.Payload())
+ if got, want := icmpv6.Type(), args.typ; got != want {
+ t.Errorf("got ICMPv6 type = %d, want = %d", got, want)
+ return
+ }
+ if fn != nil {
+ fn(t, icmpv6)
+ }
+}
+
+func TestLinkResolution(t *testing.T) {
+ c := newTestContext(t)
+ defer c.cleanup()
+
+ r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
+ defer r.Release()
+
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ payload := tcpip.SlicePayload(hdr.View())
+
+ // We can't send our payload directly over the route because that
+ // doesn't provoke NDP discovery.
+ var wq waiter.Queue
+ ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ }
+
+ for {
+ _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}})
+ if resCh != nil {
+ if err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
+ }
+ for _, args := range []routeArgs{
+ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
+ {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
+ } {
+ routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
+ if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want {
+ t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want)
+ }
+ })
+ }
+ <-resCh
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Write(_) = _, _, %s", err)
+ }
+ break
+ }
+
+ for _, args := range []routeArgs{
+ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest},
+ {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply},
+ } {
+ routeICMPv6Packet(t, args, nil)
+ }
+}
+
+func TestICMPChecksumValidationSimple(t *testing.T) {
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {
+ name: "DstUnreachable",
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ },
+ {
+ name: "PacketTooBig",
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ },
+ {
+ name: "TimeExceeded",
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ },
+ {
+ name: "ParamProblem",
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ },
+ {
+ name: "EchoRequest",
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ },
+ {
+ name: "EchoReply",
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ },
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(checksum bool) {
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView()))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(typ.size + extraDataLen),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
+
+func TestICMPChecksumValidationWithPayload(t *testing.T) {
+ const simpleBodySize = 64
+ simpleBody := func(view buffer.View) {
+ for i := 0; i < simpleBodySize; i++ {
+ view[i] = uint8(i)
+ }
+ }
+
+ const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
+ errorICMPBody := func(view buffer.View) {
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: simpleBodySize,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
+ })
+ simpleBody(view[header.IPv6MinimumSize:])
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ payloadSize int
+ payload func(buffer.View)
+ }{
+ {
+ "DstUnreachable",
+ header.ICMPv6DstUnreachable,
+ header.ICMPv6DstUnreachableMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "PacketTooBig",
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6PacketTooBigMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "TimeExceeded",
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "ParamProblem",
+ header.ICMPv6ParamProblem,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "EchoRequest",
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ {
+ "EchoReply",
+ header.ICMPv6EchoReply,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
+ icmpSize := size + payloadSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(typ)
+ payloadFn(pkt.Payload())
+
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ }
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(icmpSize),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
+
+func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
+ const simpleBodySize = 64
+ simpleBody := func(view buffer.View) {
+ for i := 0; i < simpleBodySize; i++ {
+ view[i] = uint8(i)
+ }
+ }
+
+ const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
+ errorICMPBody := func(view buffer.View) {
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: simpleBodySize,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
+ })
+ simpleBody(view[header.IPv6MinimumSize:])
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ payloadSize int
+ payload func(buffer.View)
+ }{
+ {
+ "DstUnreachable",
+ header.ICMPv6DstUnreachable,
+ header.ICMPv6DstUnreachableMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "PacketTooBig",
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6PacketTooBigMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "TimeExceeded",
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "ParamProblem",
+ header.ICMPv6ParamProblem,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "EchoRequest",
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ {
+ "EchoReply",
+ header.ICMPv6EchoReply,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
+ pkt := header.ICMPv6(hdr.Prepend(size))
+ pkt.SetType(typ)
+
+ payload := buffer.NewView(payloadSize)
+ payloadFn(payload)
+
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView()))
+ }
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(size + payloadSize),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
new file mode 100644
index 000000000..daf1fcbc6
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -0,0 +1,521 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ipv6 contains the implementation of the ipv6 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv6.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv6.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv6
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolNumber is the ipv6 protocol number.
+ ProtocolNumber = header.IPv6ProtocolNumber
+
+ // maxTotalSize is maximum size that can be encoded in the 16-bit
+ // PayloadLength field of the ipv6 header.
+ maxPayloadSize = 0xffff
+
+ // DefaultTTL is the default hop limit for IPv6 Packets egressed by
+ // Netstack.
+ DefaultTTL = 64
+)
+
+type endpoint struct {
+ nicID tcpip.NICID
+ id stack.NetworkEndpointID
+ prefixLen int
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache
+ dispatcher stack.TransportDispatcher
+ fragmentation *fragmentation.Fragmentation
+ protocol *protocol
+}
+
+// DefaultTTL is the default hop limit for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return e.protocol.DefaultTTL()
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// NICID returns the ID of the NIC this endpoint belongs to.
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+// ID returns the ipv6 endpoint ID.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &e.id
+}
+
+// PrefixLen returns the ipv6 endpoint subnet prefix length in bits.
+func (e *endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 {
+ length := uint16(hdr.UsedLength() + payloadSize)
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(params.Protocol),
+ HopLimit: params.TTL,
+ TrafficClass: params.TOS,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ return ip
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
+
+ if r.Loop&stack.PacketLoop != 0 {
+ // The inbound path expects the network header to still be in
+ // the PacketBuffer's Data field.
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
+ loopedR := r.MakeLoopedRoute()
+
+ e.HandlePacket(&loopedR, stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ })
+
+ loopedR.Release()
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return nil
+ }
+
+ r.Stats().IP.PacketsSent.Increment()
+ return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
+ panic("not implemented")
+ }
+ if r.Loop&stack.PacketOut == 0 {
+ return pkts.Len(), nil
+ }
+
+ for pb := pkts.Front(); pb != nil; pb = pb.Next() {
+ ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params)
+ pb.NetworkHeader = buffer.View(ip)
+ }
+
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+}
+
+// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
+// supported by IPv6.
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+ // TODO(b/146666412): Support IPv6 header-included packets.
+ return tcpip.ErrNotSupported
+}
+
+// HandlePacket is called by the link layer when new ipv6 packets arrive for
+// this endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
+ headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ h := header.IPv6(headerView)
+ if !h.IsValid(pkt.Data.Size()) {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ pkt.NetworkHeader = headerView[:header.IPv6MinimumSize]
+ pkt.Data.TrimFront(header.IPv6MinimumSize)
+ pkt.Data.CapLength(int(h.PayloadLength()))
+
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data)
+ hasFragmentHeader := false
+
+ for firstHeader := true; ; firstHeader = false {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6HopByHopOptionsExtHdr:
+ // As per RFC 8200 section 4.1, the Hop By Hop extension header is
+ // restricted to appear immediately after an IPv6 fixed header.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
+ // (unrecognized next header) error in response to an extension header's
+ // Next Header field with the Hop By Hop extension header identifier.
+ if !firstHeader {
+ return
+ }
+
+ optsIt := extHdr.Iter()
+
+ for {
+ opt, done, err := optsIt.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ // We currently do not support any IPv6 Hop By Hop extension header
+ // options.
+ switch opt.UnknownAction() {
+ case header.IPv6OptionUnknownActionSkip:
+ case header.IPv6OptionUnknownActionDiscard:
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ default:
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
+ }
+ }
+
+ case header.IPv6RoutingExtHdr:
+ // As per RFC 8200 section 4.4, if a node encounters a routing header with
+ // an unrecognized routing type value, with a non-zero Segments Left
+ // value, the node must discard the packet and send an ICMP Parameter
+ // Problem, Code 0. If the Segments Left is 0, the node must ignore the
+ // Routing extension header and process the next header in the packet.
+ //
+ // Note, the stack does not yet handle any type of routing extension
+ // header, so we just make sure Segments Left is zero before processing
+ // the next extension header.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
+ // unrecognized routing types with a non-zero Segments Left value.
+ if extHdr.SegmentsLeft() != 0 {
+ return
+ }
+
+ case header.IPv6FragmentExtHdr:
+ hasFragmentHeader = true
+
+ fragmentOffset := extHdr.FragmentOffset()
+ more := extHdr.More()
+ if !more && fragmentOffset == 0 {
+ // This fragment extension header indicates that this packet is an
+ // atomic fragment. An atomic fragment is a fragment that contains
+ // all the data required to reassemble a full packet. As per RFC 6946,
+ // atomic fragments must not interfere with "normal" fragmented traffic
+ // so we skip processing the fragment instead of feeding it through the
+ // reassembly process below.
+ continue
+ }
+
+ // Don't consume the iterator if we have the first fragment because we
+ // will use it to validate that the first fragment holds the upper layer
+ // header.
+ rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */)
+
+ if fragmentOffset == 0 {
+ // Check that the iterator ends with a raw payload as the first fragment
+ // should include all headers up to and including any upper layer
+ // headers, as per RFC 8200 section 4.5; only upper layer data
+ // (non-headers) should follow the fragment extension header.
+ var lastHdr header.IPv6PayloadHeader
+
+ for {
+ it, done, err := it.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ lastHdr = it
+ }
+
+ // If the last header is a raw header, then the last portion of the IPv6
+ // payload is not a known IPv6 extension header. Note, this does not
+ // mean that the last portion is an upper layer header or not an
+ // extension header because:
+ // 1) we do not yet support all extension headers
+ // 2) we do not validate the upper layer header before reassembling.
+ //
+ // This check makes sure that a known IPv6 extension header is not
+ // present after the Fragment extension header in a non-initial
+ // fragment.
+ //
+ // TODO(#2196): Support IPv6 Authentication and Encapsulated
+ // Security Payload extension headers.
+ // TODO(#2333): Validate that the upper layer header is valid.
+ switch lastHdr.(type) {
+ case header.IPv6RawPayloadHeader:
+ default:
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ }
+
+ fragmentPayloadLen := rawPayload.Buf.Size()
+ if fragmentPayloadLen == 0 {
+ // Drop the packet as it's marked as a fragment but has no payload.
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ // The packet is a fragment, let's try to reassemble it.
+ start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
+ last := start + uint16(fragmentPayloadLen) - 1
+
+ // Drop the packet if the fragmentOffset is incorrect. i.e the
+ // combination of fragmentOffset and pkt.Data.size() causes a
+ // wrap around resulting in last being less than the offset.
+ if last < start {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ var ready bool
+ pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf)
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ if ready {
+ // We create a new iterator with the reassembled packet because we could
+ // have more extension headers in the reassembled payload, as per RFC
+ // 8200 section 4.5.
+ it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data)
+ }
+
+ case header.IPv6DestinationOptionsExtHdr:
+ optsIt := extHdr.Iter()
+
+ for {
+ opt, done, err := optsIt.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ // We currently do not support any IPv6 Destination extension header
+ // options.
+ switch opt.UnknownAction() {
+ case header.IPv6OptionUnknownActionSkip:
+ case header.IPv6OptionUnknownActionDiscard:
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ default:
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
+ }
+ }
+
+ case header.IPv6RawPayloadHeader:
+ // If the last header in the payload isn't a known IPv6 extension header,
+ // handle it as if it is transport layer data.
+ pkt.Data = extHdr.Buf
+
+ if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
+ e.handleICMP(r, headerView, pkt, hasFragmentHeader)
+ } else {
+ r.Stats().IP.PacketsDelivered.Increment()
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
+ // in response to unrecognized next header values.
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
+ }
+
+ default:
+ // If we receive a packet for an extension header we do not yet handle,
+ // drop the packet for now.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
+ // in response to unrecognized next header values.
+ r.Stats().UnknownProtocolRcvdPackets.Increment()
+ return
+ }
+ }
+}
+
+// Close cleans up resources associated with the endpoint.
+func (*endpoint) Close() {}
+
+// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return e.protocol.Number()
+}
+
+type protocol struct {
+ // defaultTTL is the current default TTL for the protocol. Only the
+ // uint8 portion of it is meaningful and it must be accessed
+ // atomically.
+ defaultTTL uint32
+}
+
+// Number returns the ipv6 protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
+// MinimumPacketSize returns the minimum valid ipv6 packet size.
+func (p *protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// DefaultPrefixLen returns the IPv6 default prefix length.
+func (p *protocol) DefaultPrefixLen() int {
+ return header.IPv6AddressSize * 8
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint creates a new ipv6 endpoint.
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+ return &endpoint{
+ nicID: nicID,
+ id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ linkEP: linkEP,
+ linkAddrCache: linkAddrCache,
+ dispatcher: dispatcher,
+ fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
+ }, nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(v))
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(p.DefaultTTL())
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// SetDefaultTTL sets the default TTL for endpoints created with this protocol.
+func (p *protocol) SetDefaultTTL(ttl uint8) {
+ atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
+}
+
+// DefaultTTL returns the default TTL for endpoints created with this protocol.
+func (p *protocol) DefaultTTL() uint8 {
+ return uint8(atomic.LoadUint32(&p.defaultTTL))
+}
+
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload mtu.
+func calculateMTU(mtu uint32) uint32 {
+ mtu -= header.IPv6MinimumSize
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
+// NewProtocol returns an IPv6 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{defaultTTL: DefaultTTL}
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
new file mode 100644
index 000000000..841a0cb7a
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -0,0 +1,1265 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ // The least significant 3 bytes are the same as addr2 so both addr2 and
+ // addr3 will have the same solicited-node address.
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+ addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
+
+ // Tests use the extension header identifier values as uint8 instead of
+ // header.IPv6ExtensionHeaderIdentifier.
+ hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier)
+ routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier)
+ fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
+ destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
+ noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
+)
+
+// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
+// expected Neighbor Advertisement received count after receiving the packet.
+func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ // Receive ICMP packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+
+ if got := stats.NeighborAdvert.Value(); got != want {
+ t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
+ }
+}
+
+// testReceiveUDP tests receiving a UDP packet from src to dst. want is the
+// expected UDP received count after receiving the packet.
+func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+
+ // Receive UDP Packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+
+ // UDP pseudo-header checksum.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
+
+ // UDP checksum
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stat := s.Stats().UDP.PacketsReceived
+
+ if got := stat.Value(); got != want {
+ t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want)
+ }
+}
+
+// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
+// UDP packets destined to the IPv6 link-local all-nodes multicast address.
+func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should receive a packet destined to the all-nodes
+ // multicast address.
+ test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1)
+ })
+ }
+}
+
+// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP
+// packets destined to the IPv6 solicited-node address of an assigned IPv6
+// address.
+func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ snmc := header.SolicitedNodeAddr(addr2)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(1, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Should not receive a packet destined to the solicited node address of
+ // addr2/addr3 yet as we haven't added those addresses.
+ test.rxf(t, s, e, addr1, snmc, 0)
+
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ // Should receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have added added addr2.
+ test.rxf(t, s, e, addr1, snmc, 1)
+
+ if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
+ }
+
+ // Should still receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have added addr3.
+ test.rxf(t, s, e, addr1, snmc, 2)
+
+ if err := s.RemoveAddress(nicID, addr2); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err)
+ }
+
+ // Should still receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have removed addr2.
+ test.rxf(t, s, e, addr1, snmc, 3)
+
+ // Make sure addr3's endpoint does not get removed from the NIC by
+ // incrementing its reference count with a route.
+ r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ if err := s.RemoveAddress(nicID, addr3); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err)
+ }
+
+ // Should not receive a packet destined to the solicited node address of
+ // addr2/addr3 yet as both of them got removed, even though a route using
+ // addr3 exists.
+ test.rxf(t, s, e, addr1, snmc, 3)
+ })
+ }
+}
+
+// TestAddIpv6Address tests adding IPv6 addresses.
+func TestAddIpv6Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ }{
+ // This test is in response to b/140943433.
+ {
+ "Nil",
+ tcpip.Address([]byte(nil)),
+ },
+ {
+ "ValidUnicast",
+ addr1,
+ },
+ {
+ "ValidLinkLocalUnicast",
+ lladdr0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil {
+ t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
+ }
+
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != test.addr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr)
+ }
+ })
+ }
+}
+
+func TestReceiveIPv6ExtHdrs(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ extHdr func(nextHdr uint8) ([]byte, uint8)
+ shouldAccept bool
+ }{
+ {
+ name: "None",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop with unknown option skippable action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop with unknown option discard action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option skippable action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "destination with unknown option discard action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action unless multicast dest",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "routing - atomic fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ nextHdr, 0, 0, 0, 1, 2, 3, 4,
+ }, routingExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "atomic fragment - routing",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Fragment extension header.
+ routingExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Routing extension header.
+ nextHdr, 0, 1, 0, 2, 3, 4, 5,
+ }, fragmentExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hop by hop (with skippable unknown) - routing",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ nextHdr, 0, 1, 0, 2, 3, 4, 5,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "routing - hop by hop (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Hop By Hop extension header with skippable unknown option.
+ nextHdr, 0, 62, 4, 1, 2, 3, 4,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with skippable unknown option.
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with discard action for unknown option.
+ routingExtHdrID, 0, 65, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with skippable unknown option.
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with discard action for unknown
+ // option.
+ nextHdr, 0, 65, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8}
+ udpLength := header.UDPMinimumSize + len(udpPayload)
+ extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber))
+ extHdrLen := len(extHdrBytes)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength)
+
+ // Serialize UDP message.
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), udpPayload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum = header.Checksum(udpPayload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ // Copy extension header bytes between the UDP message and the IPv6
+ // fixed header.
+ copy(hdr.Prepend(extHdrLen), extHdrBytes)
+
+ // Serialize IPv6 fixed header.
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: ipv6NextHdr,
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ })
+
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stats := s.Stats().UDP.PacketsReceived
+
+ if !test.shouldAccept {
+ if got := stats.Value(); got != 0 {
+ t.Errorf("got UDP Rx Packets = %d, want = 0", got)
+ }
+
+ return
+ }
+
+ // Expect a UDP packet.
+ if got := stats.Value(); got != 1 {
+ t.Errorf("got UDP Rx Packets = %d, want = 1", got)
+ }
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("Read(nil): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have any more UDP packets.
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
+
+// fragmentData holds the IPv6 payload for a fragmented IPv6 packet.
+type fragmentData struct {
+ nextHdr uint8
+ data buffer.VectorisedView
+}
+
+func TestReceiveIPv6Fragments(t *testing.T) {
+ const nicID = 1
+ const udpPayload1Length = 256
+ const udpPayload2Length = 128
+ const fragmentExtHdrLen = 8
+ // Note, not all routing extension headers will be 8 bytes but this test
+ // uses 8 byte routing extension headers for most sub tests.
+ const routingExtHdrLen = 8
+
+ udpGen := func(payload []byte, multiplier uint8) buffer.View {
+ payloadLen := len(payload)
+ for i := 0; i < payloadLen; i++ {
+ payload[i] = uint8(i) * multiplier
+ }
+
+ udpLength := header.UDPMinimumSize + payloadLen
+
+ hdr := buffer.NewPrependable(udpLength)
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), payload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum = header.Checksum(payload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ return hdr.View()
+ }
+
+ var udpPayload1Buf [udpPayload1Length]byte
+ udpPayload1 := udpPayload1Buf[:]
+ ipv6Payload1 := udpGen(udpPayload1, 1)
+
+ var udpPayload2Buf [udpPayload2Length]byte
+ udpPayload2 := udpPayload2Buf[:]
+ ipv6Payload2 := udpGen(udpPayload2, 2)
+
+ tests := []struct {
+ name string
+ expectedPayload []byte
+ fragments []fragmentData
+ expectedPayloads [][]byte
+ }{
+ {
+ name: "No fragmentation",
+ fragments: []fragmentData{
+ {
+ nextHdr: uint8(header.UDPProtocolNumber),
+ data: ipv6Payload1.ToVectorisedView(),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Atomic fragment",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with different IDs",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with per-fragment routing header with zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with per-fragment routing header with non-zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with routing header with non-zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with zero segments left across fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is fragmentExtHdrLen+8 because the
+ // first 8 bytes of the 16 byte routing extension header is in
+ // this fragment.
+ fragmentExtHdrLen+8,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header (part 1)
+ //
+ // Segments left = 0.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}),
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is
+ // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // the 16 byte routing extension header is in this fagment.
+ fragmentExtHdrLen+8+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 1, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+
+ // Routing extension header (part 2)
+ buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with non-zero segments left across fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is fragmentExtHdrLen+8 because the
+ // first 8 bytes of the 16 byte routing extension header is in
+ // this fragment.
+ fragmentExtHdrLen+8,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header (part 1)
+ //
+ // Segments left = 1.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}),
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is
+ // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // the 16 byte routing extension header is in this fagment.
+ fragmentExtHdrLen+8+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 1, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+
+ // Routing extension header (part 2)
+ buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal"
+ // fragmented traffic.
+ {
+ name: "Two fragments with atomic",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ // This fragment has the same ID as the other fragments but is an atomic
+ // fragment. It should not interfere with the other fragments.
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload2),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
+
+ ipv6Payload2,
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload2, udpPayload1},
+ },
+ {
+ name: "Two interleaved fragmented packets",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
+
+ ipv6Payload2[:32],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload2)-32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 4, More = false, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
+
+ ipv6Payload2[32:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+
+ // Serialize IPv6 fixed header.
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(f.data.Size()),
+ NextHeader: f.nextHdr,
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.Append(f.data)
+
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: vv,
+ })
+ }
+
+ if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
+ t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
+ }
+
+ for i, p := range test.expectedPayloads {
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ }
+ if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" {
+ t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
new file mode 100644
index 000000000..12b70f7e9
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -0,0 +1,906 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+)
+
+// setupStackAndEndpoint creates a stack with a single NIC with a link-local
+// address llladdr and an IPv6 endpoint to a remote with link-local address
+// rlladdr
+func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
+ t.Helper()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
+
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+ if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+
+ ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
+ if err != nil {
+ t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ }
+
+ return s, ep
+}
+
+// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
+// valid NDP NS message with the Source Link Layer Address option results in a
+// new entry in the link address cache for the sender of the message.
+func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
+ }
+
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+func TestNeighorSolicitationResponse(t *testing.T) {
+ const nicID = 1
+ nicAddr := lladdr0
+ remoteAddr := lladdr1
+ nicAddrSNMC := header.SolicitedNodeAddr(nicAddr)
+ nicLinkAddr := linkAddr0
+ remoteLinkAddr0 := linkAddr1
+ remoteLinkAddr1 := linkAddr2
+
+ tests := []struct {
+ name string
+ nsOpts header.NDPOptionsSerializer
+ nsSrcLinkAddr tcpip.LinkAddress
+ nsSrc tcpip.Address
+ nsDst tcpip.Address
+ nsInvalid bool
+ naDstLinkAddr tcpip.LinkAddress
+ naSolicited bool
+ naSrc tcpip.Address
+ naDst tcpip.Address
+ }{
+ {
+ name: "Unspecified source to multicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: false,
+ naSrc: nicAddr,
+ naDst: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "Unspecified source with source ll option to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+ {
+ name: "Unspecified source to unicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: false,
+ naSrc: nicAddr,
+ naDst: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "Unspecified source with source ll option to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddr,
+ nsInvalid: true,
+ },
+
+ {
+ name: "Specified source with 1 source ll to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll different from route to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr1,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source to multicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+ {
+ name: "Specified source with 2 source ll to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+
+ {
+ name: "Specified source to unicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll different from route to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr1,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 2 source ll to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(1, 1280, nicLinkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(nicAddr)
+ opts := ns.Options()
+ opts.Serialize(test.nsOpts)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ if test.nsInvalid {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+
+ if p, got := e.Read(); got {
+ t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
+ }
+
+ // If we expected the NS to be invalid, we have nothing else to check.
+ return
+ }
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ p, got := e.Read()
+ if !got {
+ t.Fatal("expected an NDP NA response")
+ }
+
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ }
+
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(test.naSrc),
+ checker.DstAddr(test.naDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNA(
+ checker.NDPNASolicitedFlag(test.naSolicited),
+ checker.NDPNATargetAddress(nicAddr),
+ checker.NDPNAOptions([]header.NDPOption{
+ header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
+ }),
+ ))
+ })
+ }
+}
+
+// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
+// valid NDP NA message with the Target Link Layer Address option results in a
+// new entry in the link address cache for the target of the message.
+func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
+ },
+ {
+ name: "Multiple",
+ optsBuf: []byte{
+ 2, 1, 2, 3, 4, 5, 6, 7,
+ 2, 1, 2, 3, 4, 5, 6, 8,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr1)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
+ }
+
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+func TestNDPValidation(t *testing.T) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ t.Helper()
+
+ // Create a stack with the assigned link-local address lladdr0
+ // and an endpoint to lladdr1.
+ s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
+
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
+
+ return s, ep, r
+ }
+
+ handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ if atomicFragment {
+ bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength)
+ bytes[0] = nextHdr
+ nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ }
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: nextHdr,
+ HopLimit: hopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ep.HandlePacket(r, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ atomicFragment bool
+ hopLimit uint8
+ code uint8
+ valid bool
+ }{
+ {
+ name: "Valid",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: true,
+ },
+ {
+ name: "Fragmented",
+ atomicFragment: true,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid hop limit",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit - 1,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid ICMPv6 code",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 1,
+ valid: false,
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ for _, test := range subTests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
+ pkt.SetCode(test.code)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
+
+ // Rx count of the NDP message should initially be 0.
+ if got := typStat.Value(); got != 0 {
+ t.Errorf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r)
+
+ // Rx count of the NDP packet should have increased.
+ if got := typStat.Value(); got != 1 {
+ t.Errorf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ want := uint64(0)
+ if !test.valid {
+ // Invalid count should have increased.
+ want = 1
+ }
+ if got := invalid.Value(); got != want {
+ t.Errorf("got invalid = %d, want = %d", got, want)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestRouterAdvertValidation tests that when the NIC is configured to handle
+// NDP Router Advertisement packets, it validates the Router Advertisement
+// properly before handling them.
+func TestRouterAdvertValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ src tcpip.Address
+ hopLimit uint8
+ code uint8
+ ndpPayload []byte
+ expectedSuccess bool
+ }{
+ {
+ "OK",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ true,
+ },
+ {
+ "NonLinkLocalSourceAddr",
+ addr1,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "HopLimitNot255",
+ lladdr0,
+ 254,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "NonZeroCode",
+ lladdr0,
+ 255,
+ 1,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "NDPPayloadTooSmall",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "OKWithOptions",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ // RA payload
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+
+ // Option #1 (TargetLinkLayerAddress)
+ 2, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #2 (unrecognized)
+ 255, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #3 (PrefixInformation)
+ 3, 4, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ true,
+ },
+ {
+ "OptionWithZeroLength",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ // RA payload
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+
+ // Option #1 (TargetLinkLayerAddress)
+ // Invalid as it has 0 length.
+ 2, 0, 0, 0, 0, 0, 0, 0,
+
+ // Option #2 (unrecognized)
+ 255, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #3 (PrefixInformation)
+ 3, 4, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(test.code)
+ copy(pkt.NDPPayload(), test.ndpPayload)
+ payloadLength := hdr.UsedLength()
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ rxRA := stats.RouterAdvert
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
+
+ if test.expectedSuccess {
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}