summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
authorGoogler <noreply@google.com>2018-04-27 10:37:02 -0700
committerAdin Scannell <ascannell@google.com>2018-04-28 01:44:26 -0400
commitd02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch)
tree54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/tcpip/network
parentf70210e742919f40aa2f0934a22f1c9ba6dada62 (diff)
Check in gVisor.
PiperOrigin-RevId: 194583126 Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD19
-rw-r--r--pkg/tcpip/network/arp/BUILD34
-rw-r--r--pkg/tcpip/network/arp/arp.go170
-rw-r--r--pkg/tcpip/network/arp/arp_test.go138
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD61
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go67
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap_test.go112
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go124
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go166
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go109
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go95
-rw-r--r--pkg/tcpip/network/hash/BUILD11
-rw-r--r--pkg/tcpip/network/hash/hash.go83
-rw-r--r--pkg/tcpip/network/ip_test.go560
-rw-r--r--pkg/tcpip/network/ipv4/BUILD38
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go282
-rw-r--r--pkg/tcpip/network/ipv4/icmp_test.go124
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go233
-rw-r--r--pkg/tcpip/network/ipv6/BUILD21
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go80
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go172
21 files changed, 2699 insertions, 0 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
new file mode 100644
index 000000000..36ddaa692
--- /dev/null
+++ b/pkg/tcpip/network/BUILD
@@ -0,0 +1,19 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+go_test(
+ name = "ip_test",
+ size = "small",
+ srcs = [
+ "ip_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
new file mode 100644
index 000000000..e6d0899a9
--- /dev/null
+++ b/pkg/tcpip/network/arp/BUILD
@@ -0,0 +1,34 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "arp",
+ srcs = ["arp.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/arp",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "arp_test",
+ size = "small",
+ srcs = ["arp_test.go"],
+ deps = [
+ ":arp",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
new file mode 100644
index 000000000..4e3d7f597
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp.go
@@ -0,0 +1,170 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package arp implements the ARP network protocol. It is used to resolve
+// IPv4 addresses into link-local MAC addresses, and advertises IPv4
+// addresses of its stack with the local network.
+//
+// To use it in the networking stack, pass arp.ProtocolName as one of the
+// network protocols when calling stack.New. Then add an "arp" address to
+// every NIC on the stack that should respond to ARP requests. That is:
+//
+// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
+// // handle err
+// }
+package arp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolName is the string representation of the ARP protocol name.
+ ProtocolName = "arp"
+
+ // ProtocolNumber is the ARP protocol number.
+ ProtocolNumber = header.ARPProtocolNumber
+
+ // ProtocolAddress is the address expected by the ARP endpoint.
+ ProtocolAddress = tcpip.Address("arp")
+)
+
+// endpoint implements stack.NetworkEndpoint.
+type endpoint struct {
+ nicid tcpip.NICID
+ addr tcpip.Address
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache
+}
+
+func (e *endpoint) MTU() uint32 {
+ lmtu := e.linkEP.MTU()
+ return lmtu - uint32(e.MaxHeaderLength())
+}
+
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicid
+}
+
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &stack.NetworkEndpointID{ProtocolAddress}
+}
+
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.ARPSize
+}
+
+func (e *endpoint) Close() {}
+
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
+ v := vv.First()
+ h := header.ARP(v)
+ if !h.IsValid() {
+ return
+ }
+
+ switch h.Op() {
+ case header.ARPRequest:
+ localAddr := tcpip.Address(h.ProtocolAddressTarget())
+ if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+ hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
+ pkt := header.ARP(hdr.Prepend(header.ARPSize))
+ pkt.SetIPv4OverEthernet()
+ pkt.SetOp(header.ARPReply)
+ copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
+ copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
+ e.linkEP.WritePacket(r, &hdr, nil, ProtocolNumber)
+ fallthrough // also fill the cache from requests
+ case header.ARPReply:
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
+ }
+}
+
+// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
+type protocol struct {
+}
+
+func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
+func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
+
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.ARP(v)
+ return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
+}
+
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ if addr != ProtocolAddress {
+ return nil, tcpip.ErrBadLocalAddress
+ }
+ return &endpoint{
+ nicid: nicid,
+ addr: addr,
+ linkEP: sender,
+ linkAddrCache: linkAddrCache,
+ }, nil
+}
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv4ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ r := &stack.Route{
+ RemoteLinkAddress: broadcastMAC,
+ }
+
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
+ h := header.ARP(hdr.Prepend(header.ARPSize))
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), linkEP.LinkAddress())
+ copy(h.ProtocolAddressSender(), localAddr)
+ copy(h.ProtocolAddressTarget(), addr)
+
+ return linkEP.WritePacket(r, &hdr, nil, ProtocolNumber)
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\xff\xff\xff\xff" {
+ return broadcastMAC, true
+ }
+ return "", false
+}
+
+// SetOption implements NetworkProtocol.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+func init() {
+ stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
new file mode 100644
index 000000000..91ffdce4b
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -0,0 +1,138 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package arp_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/arp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+ stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
+ stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
+ stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+)
+
+type testContext struct {
+ t *testing.T
+ linkEP *channel.Endpoint
+ s *stack.Stack
+}
+
+func newTestContext(t *testing.T) *testContext {
+ s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{ipv4.PingProtocolName})
+
+ const defaultMTU = 65536
+ id, linkEP := channel.New(256, defaultMTU, stackLinkAddr)
+ if testing.Verbose() {
+ id = sniffer.New(id)
+ }
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
+ if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("AddAddress for arp failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: "\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ }})
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEP: linkEP,
+ }
+}
+
+func (c *testContext) cleanup() {
+ close(c.linkEP.C)
+}
+
+func TestDirectRequest(t *testing.T) {
+ c := newTestContext(t)
+ defer c.cleanup()
+
+ const senderMAC = "\x01\x02\x03\x04\x05\x06"
+ const senderIPv4 = "\x0a\x00\x00\x02"
+
+ v := make(buffer.View, header.ARPSize)
+ h := header.ARP(v)
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), senderMAC)
+ copy(h.ProtocolAddressSender(), senderIPv4)
+
+ // stackAddr1
+ copy(h.ProtocolAddressTarget(), stackAddr1)
+ vv := v.ToVectorisedView([1]buffer.View{})
+ c.linkEP.Inject(arp.ProtocolNumber, &vv)
+ pkt := <-c.linkEP.C
+ if pkt.Proto != arp.ProtocolNumber {
+ t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto)
+ }
+ rep := header.ARP(pkt.Header)
+ if !rep.IsValid() {
+ t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ }
+ if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 {
+ t.Errorf("stackAddr1: expected sender to be set")
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+ t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got)
+ }
+
+ // stackAddr2
+ copy(h.ProtocolAddressTarget(), stackAddr2)
+ vv = v.ToVectorisedView([1]buffer.View{})
+ c.linkEP.Inject(arp.ProtocolNumber, &vv)
+ pkt = <-c.linkEP.C
+ if pkt.Proto != arp.ProtocolNumber {
+ t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto)
+ }
+ rep = header.ARP(pkt.Header)
+ if !rep.IsValid() {
+ t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ }
+ if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 {
+ t.Errorf("stackAddr2: expected sender to be set")
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+ t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got)
+ }
+
+ // stackAddrBad
+ copy(h.ProtocolAddressTarget(), stackAddrBad)
+ vv = v.ToVectorisedView([1]buffer.View{})
+ c.linkEP.Inject(arp.ProtocolNumber, &vv)
+ select {
+ case pkt := <-c.linkEP.C:
+ t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
+ case <-time.After(100 * time.Millisecond):
+ // Sleep tests are gross, but this will only
+ // potentially fail flakily if there's a bugj
+ // If there is no bug this will reliably succeed.
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
new file mode 100644
index 000000000..78fe878ec
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -0,0 +1,61 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "fragmentation_state",
+ srcs = ["reassembler_list.go"],
+ out = "fragmentation_state.go",
+ package = "fragmentation",
+)
+
+go_template_instance(
+ name = "reassembler_list",
+ out = "reassembler_list.go",
+ package = "fragmentation",
+ prefix = "reassembler",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Linker": "*reassembler",
+ },
+)
+
+go_library(
+ name = "fragmentation",
+ srcs = [
+ "frag_heap.go",
+ "fragmentation.go",
+ "fragmentation_state.go",
+ "reassembler.go",
+ "reassembler_list.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/fragmentation",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//pkg/state",
+ "//pkg/tcpip/buffer",
+ ],
+)
+
+go_test(
+ name = "fragmentation_test",
+ size = "small",
+ srcs = [
+ "frag_heap_test.go",
+ "fragmentation_test.go",
+ "reassembler_test.go",
+ ],
+ embed = [":fragmentation"],
+ deps = ["//pkg/tcpip/buffer"],
+)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "reassembler_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
new file mode 100644
index 000000000..2e8512909
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/frag_heap.go
@@ -0,0 +1,67 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+type fragment struct {
+ offset uint16
+ vv *buffer.VectorisedView
+}
+
+type fragHeap []fragment
+
+func (h *fragHeap) Len() int {
+ return len(*h)
+}
+
+func (h *fragHeap) Less(i, j int) bool {
+ return (*h)[i].offset < (*h)[j].offset
+}
+
+func (h *fragHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
+}
+
+func (h *fragHeap) Push(x interface{}) {
+ *h = append(*h, x.(fragment))
+}
+
+func (h *fragHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[:n-1]
+ return x
+}
+
+// reassamble empties the heap and returns a VectorisedView
+// containing a reassambled version of the fragments inside the heap.
+func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
+ curr := heap.Pop(h).(fragment)
+ views := curr.vv.Views()
+ size := curr.vv.Size()
+
+ if curr.offset != 0 {
+ return buffer.NewVectorisedView(0, nil), fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
+ }
+
+ for h.Len() > 0 {
+ curr := heap.Pop(h).(fragment)
+ if int(curr.offset) < size {
+ curr.vv.TrimFront(size - int(curr.offset))
+ } else if int(curr.offset) > size {
+ return buffer.NewVectorisedView(0, nil), fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
+ }
+ size += curr.vv.Size()
+ views = append(views, curr.vv.Views()...)
+ }
+ return buffer.NewVectorisedView(size, views), nil
+}
diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go
new file mode 100644
index 000000000..218a24d7b
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/frag_heap_test.go
@@ -0,0 +1,112 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "reflect"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+var reassambleTestCases = []struct {
+ comment string
+ in []fragment
+ want *buffer.VectorisedView
+}{
+ {
+ comment: "Non-overlapping in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(1, "0")},
+ {offset: 1, vv: vv(1, "1")},
+ },
+ want: vv(2, "0", "1"),
+ },
+ {
+ comment: "Non-overlapping out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(1, "0")},
+ },
+ want: vv(2, "0", "1"),
+ },
+ {
+ comment: "Duplicated packets",
+ in: []fragment{
+ {offset: 0, vv: vv(1, "0")},
+ {offset: 0, vv: vv(1, "0")},
+ },
+ want: vv(1, "0"),
+ },
+ {
+ comment: "Overlapping in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(2, "01")},
+ {offset: 1, vv: vv(2, "12")},
+ },
+ want: vv(3, "01", "2"),
+ },
+ {
+ comment: "Overlapping out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(2, "12")},
+ {offset: 0, vv: vv(2, "01")},
+ },
+ want: vv(3, "01", "2"),
+ },
+ {
+ comment: "Overlapping subset in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(3, "012")},
+ {offset: 1, vv: vv(1, "1")},
+ },
+ want: vv(3, "012"),
+ },
+ {
+ comment: "Overlapping subset out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(3, "012")},
+ },
+ want: vv(3, "012"),
+ },
+}
+
+func TestReassamble(t *testing.T) {
+ for _, c := range reassambleTestCases {
+ h := (fragHeap)(make([]fragment, 0, 8))
+ heap.Init(&h)
+ for _, f := range c.in {
+ heap.Push(&h, f)
+ }
+ got, _ := h.reassemble()
+
+ if !reflect.DeepEqual(got, *c.want) {
+ t.Errorf("Test \"%s\" reassembling failed. Got %v. Want %v", c.comment, got, *c.want)
+ }
+ }
+}
+
+func TestReassambleFailsForNonZeroOffset(t *testing.T) {
+ h := (fragHeap)(make([]fragment, 0, 8))
+ heap.Init(&h)
+ heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
+ _, err := h.reassemble()
+ if err == nil {
+ t.Errorf("reassemble() did not fail when the first packet had offset != 0")
+ }
+}
+
+func TestReassambleFailsForHoles(t *testing.T) {
+ h := (fragHeap)(make([]fragment, 0, 8))
+ heap.Init(&h)
+ heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
+ heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
+ _, err := h.reassemble()
+ if err == nil {
+ t.Errorf("reassemble() did not fail when there was a hole in the packet")
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
new file mode 100644
index 000000000..a309a24c5
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -0,0 +1,124 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fragmentation contains the implementation of IP fragmentation.
+// It is based on RFC 791 and RFC 815.
+package fragmentation
+
+import (
+ "log"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
+const DefaultReassembleTimeout = 30 * time.Second
+
+// HighFragThreshold is the threshold at which we start trimming old
+// fragmented packets. Linux uses a default value of 4 MB. See
+// net.ipv4.ipfrag_high_thresh for more information.
+const HighFragThreshold = 4 << 20 // 4MB
+
+// LowFragThreshold is the threshold we reach to when we start dropping
+// older fragmented packets. It's important that we keep enough room for newer
+// packets to be re-assembled. Hence, this needs to be lower than
+// HighFragThreshold enough. Linux uses a default value of 3 MB. See
+// net.ipv4.ipfrag_low_thresh for more information.
+const LowFragThreshold = 3 << 20 // 3MB
+
+// Fragmentation is the main structure that other modules
+// of the stack should use to implement IP Fragmentation.
+type Fragmentation struct {
+ mu sync.Mutex
+ highLimit int
+ lowLimit int
+ reassemblers map[uint32]*reassembler
+ rList reassemblerList
+ size int
+ timeout time.Duration
+}
+
+// NewFragmentation creates a new Fragmentation.
+//
+// highMemoryLimit specifies the limit on the memory consumed
+// by the fragments stored by Fragmentation (overhead of internal data-structures
+// is not accounted). Fragments are dropped when the limit is reached.
+//
+// lowMemoryLimit specifies the limit on which we will reach by dropping
+// fragments after reaching highMemoryLimit.
+//
+// reassemblingTimeout specifes the maximum time allowed to reassemble a packet.
+// Fragments are lazily evicted only when a new a packet with an
+// already existing fragmentation-id arrives after the timeout.
+func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+ if lowMemoryLimit >= highMemoryLimit {
+ lowMemoryLimit = highMemoryLimit
+ }
+
+ if lowMemoryLimit < 0 {
+ lowMemoryLimit = 0
+ }
+
+ return &Fragmentation{
+ reassemblers: make(map[uint32]*reassembler),
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
+ timeout: reassemblingTimeout,
+ }
+}
+
+// Process processes an incoming fragment beloning to an ID
+// and returns a complete packet when all the packets belonging to that ID have been received.
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv *buffer.VectorisedView) (buffer.VectorisedView, bool) {
+ f.mu.Lock()
+ r, ok := f.reassemblers[id]
+ if ok && r.tooOld(f.timeout) {
+ // This is very likely to be an id-collision or someone performing a slow-rate attack.
+ f.release(r)
+ ok = false
+ }
+ if !ok {
+ r = newReassembler(id)
+ f.reassemblers[id] = r
+ f.rList.PushFront(r)
+ }
+ f.mu.Unlock()
+
+ res, done, consumed := r.process(first, last, more, vv)
+
+ f.mu.Lock()
+ f.size += consumed
+ if done {
+ f.release(r)
+ }
+ // Evict reassemblers if we are consuming more memory than highLimit until
+ // we reach lowLimit.
+ if f.size > f.highLimit {
+ tail := f.rList.Back()
+ for f.size > f.lowLimit && tail != nil {
+ f.release(tail)
+ tail = tail.Prev()
+ }
+ }
+ f.mu.Unlock()
+ return res, done
+}
+
+func (f *Fragmentation) release(r *reassembler) {
+ // Before releasing a fragment we need to check if r is already marked as done.
+ // Otherwise, we would delete it twice.
+ if r.checkDoneOrMark() {
+ return
+ }
+
+ delete(f.reassemblers, r.id)
+ f.rList.Remove(r)
+ f.size -= r.size
+ if f.size < 0 {
+ log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
+ f.size = 0
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
new file mode 100644
index 000000000..2f0200d26
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -0,0 +1,166 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fragmentation
+
+import (
+ "reflect"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// vv is a helper to build VectorisedView from different strings.
+func vv(size int, pieces ...string) *buffer.VectorisedView {
+ views := make([]buffer.View, len(pieces))
+ for i, p := range pieces {
+ views[i] = []byte(p)
+ }
+
+ vv := buffer.NewVectorisedView(size, views)
+ return &vv
+}
+
+func emptyVv() *buffer.VectorisedView {
+ vv := buffer.NewVectorisedView(0, nil)
+ return &vv
+}
+
+type processInput struct {
+ id uint32
+ first uint16
+ last uint16
+ more bool
+ vv *buffer.VectorisedView
+}
+
+type processOutput struct {
+ vv *buffer.VectorisedView
+ done bool
+}
+
+var processTestCases = []struct {
+ comment string
+ in []processInput
+ out []processOutput
+}{
+ {
+ comment: "One ID",
+ in: []processInput{
+ {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: emptyVv(), done: false},
+ {vv: vv(4, "01", "23"), done: true},
+ },
+ },
+ {
+ comment: "Two IDs",
+ in: []processInput{
+ {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
+ {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
+ {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: emptyVv(), done: false},
+ {vv: emptyVv(), done: false},
+ {vv: vv(4, "ab", "cd"), done: true},
+ {vv: vv(4, "01", "23"), done: true},
+ },
+ },
+}
+
+func TestFragmentationProcess(t *testing.T) {
+ for _, c := range processTestCases {
+ f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ for i, in := range c.in {
+ vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ if !reflect.DeepEqual(vv, *(c.out[i].vv)) {
+ t.Errorf("Test \"%s\" Process() returned a wrong vv. Got %v. Want %v", c.comment, vv, *(c.out[i].vv))
+ }
+ if done != c.out[i].done {
+ t.Errorf("Test \"%s\" Process() returned a wrong done. Got %t. Want %t", c.comment, done, c.out[i].done)
+ }
+ if c.out[i].done {
+ if _, ok := f.reassemblers[in.id]; ok {
+ t.Errorf("Test \"%s\" Process() didn't remove buffer from reassemblers.", c.comment)
+ }
+ for n := f.rList.Front(); n != nil; n = n.Next() {
+ if n.id == in.id {
+ t.Errorf("Test \"%s\" Process() didn't remove buffer from rList.", c.comment)
+ }
+ }
+ }
+ }
+ }
+}
+
+func TestReassemblingTimeout(t *testing.T) {
+ timeout := time.Millisecond
+ f := NewFragmentation(1024, 512, timeout)
+ // Send first fragment with id = 0, first = 0, last = 0, and more = true.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Sleep more than the timeout.
+ time.Sleep(2 * timeout)
+ // Send another fragment that completes a packet.
+ // However, no packet should be reassembled because the fragment arrived after the timeout.
+ _, done := f.Process(0, 1, 1, false, vv(1, "1"))
+ if done {
+ t.Errorf("Fragmentation does not respect the reassembling timeout.")
+ }
+}
+
+func TestMemoryLimits(t *testing.T) {
+ f := NewFragmentation(3, 1, DefaultReassembleTimeout)
+ // Send first fragment with id = 0.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Send first fragment with id = 1.
+ f.Process(1, 0, 0, true, vv(1, "1"))
+ // Send first fragment with id = 2.
+ f.Process(2, 0, 0, true, vv(1, "2"))
+
+ // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
+ // evicted.
+ f.Process(3, 0, 0, true, vv(1, "3"))
+
+ if _, ok := f.reassemblers[0]; ok {
+ t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[1]; ok {
+ t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[3]; !ok {
+ t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
+ }
+}
+
+func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
+ f := NewFragmentation(1, 0, DefaultReassembleTimeout)
+ // Send first fragment with id = 0.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Send the same packet again.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+
+ got := f.size
+ want := 1
+ if got != want {
+ t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
+ }
+}
+
+func TestFragmentationViewsDoNotEscape(t *testing.T) {
+ f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ in := vv(2, "0", "1")
+ f.Process(0, 0, 1, true, in)
+ // Modify input view.
+ in.RemoveFirst()
+ got, _ := f.Process(0, 2, 2, false, vv(1, "2"))
+ want := vv(3, "0", "1", "2")
+ if !reflect.DeepEqual(got, *want) {
+ t.Errorf("Process() returned a wrong vv. Got %v. Want %v", got, *want)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
new file mode 100644
index 000000000..0267a575d
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -0,0 +1,109 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+ "math"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+type hole struct {
+ first uint16
+ last uint16
+ deleted bool
+}
+
+type reassembler struct {
+ reassemblerEntry
+ id uint32
+ size int
+ mu sync.Mutex
+ holes []hole
+ deleted int
+ heap fragHeap
+ done bool
+ creationTime time.Time
+}
+
+func newReassembler(id uint32) *reassembler {
+ r := &reassembler{
+ id: id,
+ holes: make([]hole, 0, 16),
+ deleted: 0,
+ heap: make(fragHeap, 0, 8),
+ creationTime: time.Now(),
+ }
+ r.holes = append(r.holes, hole{
+ first: 0,
+ last: math.MaxUint16,
+ deleted: false})
+ return r
+}
+
+// updateHoles updates the list of holes for an incoming fragment and
+// returns true iff the fragment filled at least part of an existing hole.
+func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
+ used := false
+ for i := range r.holes {
+ if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
+ continue
+ }
+ used = true
+ r.deleted++
+ r.holes[i].deleted = true
+ if first > r.holes[i].first {
+ r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
+ }
+ if last < r.holes[i].last && more {
+ r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
+ }
+ }
+ return used
+}
+
+func (r *reassembler) process(first, last uint16, more bool, vv *buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ consumed := 0
+ if r.done {
+ // A concurrent goroutine might have already reassembled
+ // the packet and emptied the heap while this goroutine
+ // was waiting on the mutex. We don't have to do anything in this case.
+ return buffer.NewVectorisedView(0, nil), false, consumed
+ }
+ if r.updateHoles(first, last, more) {
+ // We store the incoming packet only if it filled some holes.
+ uu := vv.Clone(nil)
+ heap.Push(&r.heap, fragment{offset: first, vv: &uu})
+ consumed = vv.Size()
+ r.size += consumed
+ }
+ // Check if all the holes have been deleted and we are ready to reassamble.
+ if r.deleted < len(r.holes) {
+ return buffer.NewVectorisedView(0, nil), false, consumed
+ }
+ res, err := r.heap.reassemble()
+ if err != nil {
+ panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
+ }
+ return res, true, consumed
+}
+
+func (r *reassembler) tooOld(timeout time.Duration) bool {
+ return time.Now().Sub(r.creationTime) > timeout
+}
+
+func (r *reassembler) checkDoneOrMark() bool {
+ r.mu.Lock()
+ prev := r.done
+ r.done = true
+ r.mu.Unlock()
+ return prev
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
new file mode 100644
index 000000000..b64604383
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -0,0 +1,95 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fragmentation
+
+import (
+ "math"
+ "reflect"
+ "testing"
+)
+
+type updateHolesInput struct {
+ first uint16
+ last uint16
+ more bool
+}
+
+var holesTestCases = []struct {
+ comment string
+ in []updateHolesInput
+ want []hole
+}{
+ {
+ comment: "No fragments. Expected holes: {[0 -> inf]}.",
+ in: []updateHolesInput{},
+ want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
+ },
+ {
+ comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
+ in: []updateHolesInput{{first: 0, last: 1, more: true}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 2, last: math.MaxUint16, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
+ in: []updateHolesInput{{first: 1, last: 2, more: true}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: 0, deleted: false},
+ {first: 3, last: math.MaxUint16, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment at the end. Expected holes: {[0, 0]}.",
+ in: []updateHolesInput{{first: 1, last: 2, more: false}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: 0, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{{first: 0, last: 1, more: false}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ },
+ },
+ {
+ comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{
+ {first: 0, last: 1, more: true},
+ {first: 2, last: 3, more: false},
+ },
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 2, last: math.MaxUint16, deleted: true},
+ },
+ },
+ {
+ comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{
+ {first: 0, last: 2, more: true},
+ {first: 2, last: 3, more: false},
+ },
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 3, last: math.MaxUint16, deleted: true},
+ },
+ },
+}
+
+func TestUpdateHoles(t *testing.T) {
+ for _, c := range holesTestCases {
+ r := newReassembler(0)
+ for _, i := range c.in {
+ r.updateHoles(i.first, i.last, i.more)
+ }
+ if !reflect.DeepEqual(r.holes, c.want) {
+ t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
new file mode 100644
index 000000000..96805c690
--- /dev/null
+++ b/pkg/tcpip/network/hash/BUILD
@@ -0,0 +1,11 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "hash",
+ srcs = ["hash.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/hash",
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/tcpip/header"],
+)
diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go
new file mode 100644
index 000000000..e5a696158
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash.go
@@ -0,0 +1,83 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package hash contains utility functions for hashing.
+package hash
+
+import (
+ "crypto/rand"
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+var hashIV = RandN32(1)[0]
+
+// RandN32 generates a slice of n cryptographic random 32-bit numbers.
+func RandN32(n int) []uint32 {
+ b := make([]byte, 4*n)
+ if _, err := rand.Read(b); err != nil {
+ panic("unable to get random numbers: " + err.Error())
+ }
+ r := make([]uint32, n)
+ for i := range r {
+ r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)])
+ }
+ return r
+}
+
+// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted
+// from linux.
+func Hash3Words(a, b, c, initval uint32) uint32 {
+ const iv = 0xdeadbeef + (3 << 2)
+ initval += iv
+
+ a += initval
+ b += initval
+ c += initval
+
+ c ^= b
+ c -= rol32(b, 14)
+ a ^= c
+ a -= rol32(c, 11)
+ b ^= a
+ b -= rol32(a, 25)
+ c ^= b
+ c -= rol32(b, 16)
+ a ^= c
+ a -= rol32(c, 4)
+ b ^= a
+ b -= rol32(a, 14)
+ c ^= b
+ c -= rol32(b, 24)
+
+ return c
+}
+
+// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791.
+func IPv4FragmentHash(h header.IPv4) uint32 {
+ x := uint32(h.ID())<<16 | uint32(h.Protocol())
+ t := h.SourceAddress()
+ y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = h.DestinationAddress()
+ z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return Hash3Words(x, y, z, hashIV)
+}
+
+// IPv6FragmentHash computes the hash of the ipv6 fragment.
+// Unlike IPv4, the protocol is not used to compute the hash.
+// RFC 2640 (sec 4.5) is not very sharp on this aspect.
+// As a reference, also Linux ignores the protocol to compute
+// the hash (inet6_hash_frag).
+func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 {
+ t := h.SourceAddress()
+ y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = h.DestinationAddress()
+ z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return Hash3Words(f.ID(), y, z, hashIV)
+}
+
+func rol32(v, shift uint32) uint32 {
+ return (v << shift) | (v >> ((-shift) & 31))
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
new file mode 100644
index 000000000..797501858
--- /dev/null
+++ b/pkg/tcpip/network/ip_test.go
@@ -0,0 +1,560 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ip_test
+
+import (
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
+// The former is used to pretend that it's a link endpoint so that we can
+// inspect packets written by the network endpoints. The latter is used to
+// pretend that it's the network stack so that it can inspect incoming packets
+// that have been handled by the network endpoints.
+//
+// Packets are checked by comparing their fields/values against the expected
+// values stored in the test object itself.
+type testObject struct {
+ t *testing.T
+ protocol tcpip.TransportProtocolNumber
+ contents []byte
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ v4 bool
+ typ stack.ControlType
+ extra uint32
+
+ dataCalls int
+ controlCalls int
+}
+
+// checkValues verifies that the transport protocol, data contents, src & dst
+// addresses of a packet match what's expected. If any field doesn't match, the
+// test fails.
+func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) {
+ v := vv.ToView()
+ if protocol != t.protocol {
+ t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
+ }
+
+ if srcAddr != t.srcAddr {
+ t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr)
+ }
+
+ if dstAddr != t.dstAddr {
+ t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr)
+ }
+
+ if len(v) != len(t.contents) {
+ t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents))
+ }
+
+ for i := range t.contents {
+ if t.contents[i] != v[i] {
+ t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i])
+ }
+ }
+}
+
+// DeliverTransportPacket is called by network endpoints after parsing incoming
+// packets. This is used by the test object to verify that the results of the
+// parsing are expected.
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) {
+ t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress)
+ t.dataCalls++
+}
+
+// DeliverTransportControlPacket is called by network endpoints after parsing
+// incoming control (ICMP) packets. This is used by the test object to verify
+// that the results of the parsing are expected.
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
+ t.checkValues(trans, vv, remote, local)
+ if typ != t.typ {
+ t.t.Errorf("typ = %v, want %v", typ, t.typ)
+ }
+ if extra != t.extra {
+ t.t.Errorf("extra = %v, want %v", extra, t.extra)
+ }
+ t.controlCalls++
+}
+
+// Attach is only implemented to satisfy the LinkEndpoint interface.
+func (*testObject) Attach(stack.NetworkDispatcher) {}
+
+// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
+// matches the linux loopback MTU.
+func (*testObject) MTU() uint32 {
+ return 65536
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
+// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
+func (*testObject) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*testObject) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// WritePacket is called by network endpoints after producing a packet and
+// writing it to the link endpoint. This is used by the test object to verify
+// that the produced packet is as expected.
+func (t *testObject) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ var prot tcpip.TransportProtocolNumber
+ var srcAddr tcpip.Address
+ var dstAddr tcpip.Address
+
+ if t.v4 {
+ h := header.IPv4(hdr.UsedBytes())
+ prot = tcpip.TransportProtocolNumber(h.Protocol())
+ srcAddr = h.SourceAddress()
+ dstAddr = h.DestinationAddress()
+
+ } else {
+ h := header.IPv6(hdr.UsedBytes())
+ prot = tcpip.TransportProtocolNumber(h.NextHeader())
+ srcAddr = h.SourceAddress()
+ dstAddr = h.DestinationAddress()
+ }
+ var views [1]buffer.View
+ vv := payload.ToVectorisedView(views)
+ t.checkValues(prot, &vv, srcAddr, dstAddr)
+ return nil
+}
+
+func TestIPv4Send(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, nil, &o)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Allocate and initialize the payload view.
+ payload := buffer.NewView(100)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i)
+ }
+
+ // Allocate the header buffer.
+ hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+
+ // Issue the write.
+ o.protocol = 123
+ o.srcAddr = "\x0a\x00\x00\x01"
+ o.dstAddr = "\x0a\x00\x00\x02"
+ o.contents = payload
+
+ r := stack.Route{
+ RemoteAddress: o.dstAddr,
+ LocalAddress: o.srcAddr,
+ }
+ if err := ep.WritePacket(&r, &hdr, payload, 123); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+}
+
+func TestIPv4Receive(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv4MinimumSize + 30
+ view := buffer.NewView(totalLen)
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ SrcAddr: "\x0a\x00\x00\x02",
+ DstAddr: "\x0a\x00\x00\x01",
+ })
+
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = "\x0a\x00\x00\x02"
+ o.dstAddr = "\x0a\x00\x00\x01"
+ o.contents = view[header.IPv4MinimumSize:totalLen]
+
+ r := stack.Route{
+ LocalAddress: o.dstAddr,
+ RemoteAddress: o.srcAddr,
+ }
+ var views [1]buffer.View
+ vv := view.ToVectorisedView(views)
+ ep.HandlePacket(&r, &vv)
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv4ReceiveControl(t *testing.T) {
+ const mtu = 0xbeef - header.IPv4MinimumSize
+ cases := []struct {
+ name string
+ expectedCount int
+ fragmentOffset uint16
+ code uint8
+ expectedTyp stack.ControlType
+ expectedExtra uint32
+ trunc int
+ }{
+ {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0},
+ {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10},
+ {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8},
+ {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8},
+ {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8},
+ }
+ r := stack.Route{
+ LocalAddress: "\x0a\x00\x00\x01",
+ RemoteAddress: "\x0a\x00\x00\xbb",
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ var views [1]buffer.View
+ o := testObject{t: t}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4
+ view := buffer.NewView(dataOffset + 8)
+
+ // Create the outer IPv4 header.
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(len(view) - c.trunc),
+ TTL: 20,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ SrcAddr: "\x0a\x00\x00\xbb",
+ DstAddr: "\x0a\x00\x00\x01",
+ })
+
+ // Create the ICMP header.
+ icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
+ icmp.SetType(header.ICMPv4DstUnreachable)
+ icmp.SetCode(c.code)
+ copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
+
+ // Create the inner IPv4 header.
+ ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:])
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: 100,
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: c.fragmentOffset,
+ SrcAddr: "\x0a\x00\x00\x01",
+ DstAddr: "\x0a\x00\x00\x02",
+ })
+
+ // Make payload be non-zero.
+ for i := dataOffset; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to IPv4 endpoint, dispatcher will validate that
+ // it's ok.
+ o.protocol = 10
+ o.srcAddr = "\x0a\x00\x00\x02"
+ o.dstAddr = "\x0a\x00\x00\x01"
+ o.contents = view[dataOffset:]
+ o.typ = c.expectedTyp
+ o.extra = c.expectedExtra
+
+ vv := view.ToVectorisedView(views)
+ vv.CapLength(len(view) - c.trunc)
+ ep.HandlePacket(&r, &vv)
+ if want := c.expectedCount; o.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ }
+ })
+ }
+}
+
+func TestIPv4FragmentationReceive(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv4MinimumSize + 24
+
+ frag1 := buffer.NewView(totalLen)
+ ip1 := header.IPv4(frag1)
+ ip1.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: 0,
+ Flags: header.IPv4FlagMoreFragments,
+ SrcAddr: "\x0a\x00\x00\x02",
+ DstAddr: "\x0a\x00\x00\x01",
+ })
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ frag1[i] = uint8(i)
+ }
+
+ frag2 := buffer.NewView(totalLen)
+ ip2 := header.IPv4(frag2)
+ ip2.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: 24,
+ SrcAddr: "\x0a\x00\x00\x02",
+ DstAddr: "\x0a\x00\x00\x01",
+ })
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ frag2[i] = uint8(i)
+ }
+
+ // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = "\x0a\x00\x00\x02"
+ o.dstAddr = "\x0a\x00\x00\x01"
+ o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+
+ r := stack.Route{
+ LocalAddress: o.dstAddr,
+ RemoteAddress: o.srcAddr,
+ }
+
+ // Send first segment.
+ var views1 [1]buffer.View
+ vv1 := frag1.ToVectorisedView(views1)
+ ep.HandlePacket(&r, &vv1)
+ if o.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ }
+
+ // Send second segment.
+ var views2 [1]buffer.View
+ vv2 := frag2.ToVectorisedView(views2)
+ ep.HandlePacket(&r, &vv2)
+
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv6Send(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, nil, &o)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Allocate and initialize the payload view.
+ payload := buffer.NewView(100)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i)
+ }
+
+ // Allocate the header buffer.
+ hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+
+ // Issue the write.
+ o.protocol = 123
+ o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ o.contents = payload
+
+ r := stack.Route{
+ RemoteAddress: o.dstAddr,
+ LocalAddress: o.srcAddr,
+ }
+ if err := ep.WritePacket(&r, &hdr, payload, 123); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+}
+
+func TestIPv6Receive(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv6MinimumSize + 30
+ view := buffer.NewView(totalLen)
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02",
+ DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ })
+
+ // Make payload be non-zero.
+ for i := header.IPv6MinimumSize; i < totalLen; i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ o.contents = view[header.IPv6MinimumSize:totalLen]
+
+ r := stack.Route{
+ LocalAddress: o.dstAddr,
+ RemoteAddress: o.srcAddr,
+ }
+ var views [1]buffer.View
+ vv := view.ToVectorisedView(views)
+ ep.HandlePacket(&r, &vv)
+
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv6ReceiveControl(t *testing.T) {
+ newUint16 := func(v uint16) *uint16 { return &v }
+
+ const mtu = 0xffff
+ cases := []struct {
+ name string
+ expectedCount int
+ fragmentOffset *uint16
+ typ header.ICMPv6Type
+ code uint8
+ expectedTyp stack.ControlType
+ expectedExtra uint32
+ trunc int
+ }{
+ {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0},
+ {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10},
+ {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8},
+ {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8},
+ {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8},
+ {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
+ }
+ r := stack.Route{
+ LocalAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ RemoteAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ var views [1]buffer.View
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ defer ep.Close()
+
+ dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4
+ if c.fragmentOffset != nil {
+ dataOffset += header.IPv6FragmentHeaderSize
+ }
+ view := buffer.NewView(dataOffset + 8)
+
+ // Create the outer IPv6 header.
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 20,
+ SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
+ DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ })
+
+ // Create the ICMP header.
+ icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
+ icmp.SetType(c.typ)
+ icmp.SetCode(c.code)
+ copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
+
+ // Create the inner IPv6 header.
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: 100,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02",
+ })
+
+ // Build the fragmentation header if needed.
+ if c.fragmentOffset != nil {
+ ip.SetNextHeader(header.IPv6FragmentHeader)
+ frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ frag.Encode(&header.IPv6FragmentFields{
+ NextHeader: 10,
+ FragmentOffset: *c.fragmentOffset,
+ M: true,
+ Identification: 0x12345678,
+ })
+ }
+
+ // Make payload be non-zero.
+ for i := dataOffset; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to IPv6 endpoint, dispatcher will validate that
+ // it's ok.
+ o.protocol = 10
+ o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ o.contents = view[dataOffset:]
+ o.typ = c.expectedTyp
+ o.extra = c.expectedExtra
+
+ vv := view.ToVectorisedView(views)
+ vv.CapLength(len(view) - c.trunc)
+ ep.HandlePacket(&r, &vv)
+ if want := c.expectedCount; o.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
new file mode 100644
index 000000000..9df113df1
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -0,0 +1,38 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "ipv4",
+ srcs = [
+ "icmp.go",
+ "ipv4.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "ipv4_test",
+ size = "small",
+ srcs = ["icmp_test.go"],
+ deps = [
+ ":ipv4",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
new file mode 100644
index 000000000..ffd761350
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -0,0 +1,282 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ipv4
+
+import (
+ "context"
+ "encoding/binary"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// PingProtocolName is a pseudo transport protocol used to handle ping replies.
+// Use it when constructing a stack that intends to use ipv4.Ping.
+const PingProtocolName = "icmpv4ping"
+
+// pingProtocolNumber is a fake transport protocol used to
+// deliver incoming ICMP echo replies. The ICMP identifier
+// number is used as a port number for multiplexing.
+const pingProtocolNumber tcpip.TransportProtocolNumber = 256 + 11
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
+ h := header.IPv4(vv.First())
+
+ // We don't use IsValid() here because ICMP only requires that the IP
+ // header plus 8 bytes of the transport header be included. So it's
+ // likely that it is truncated, which would cause IsValid to return
+ // false.
+ //
+ // Drop packet if it doesn't have the basic IPv4 header or if the
+ // original source address doesn't match the endpoint's address.
+ if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ hlen := int(h.HeaderLength())
+ if vv.Size() < hlen || h.FragmentOffset() != 0 {
+ // We won't be able to handle this if it doesn't contain the
+ // full IPv4 header, or if it's a fragment not at offset 0
+ // (because it won't have the transport header).
+ return
+ }
+
+ // Skip the ip header, then deliver control message.
+ vv.TrimFront(hlen)
+ p := h.TransportProtocol()
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
+ v := vv.First()
+ if len(v) < header.ICMPv4MinimumSize {
+ return
+ }
+ h := header.ICMPv4(v)
+
+ switch h.Type() {
+ case header.ICMPv4Echo:
+ if len(v) < header.ICMPv4EchoMinimumSize {
+ return
+ }
+ vv.TrimFront(header.ICMPv4MinimumSize)
+ req := echoRequest{r: r.Clone(), v: vv.ToView()}
+ select {
+ case e.echoRequests <- req:
+ default:
+ req.r.Release()
+ }
+
+ case header.ICMPv4EchoReply:
+ e.dispatcher.DeliverTransportPacket(r, pingProtocolNumber, vv)
+
+ case header.ICMPv4DstUnreachable:
+ if len(v) < header.ICMPv4DstUnreachableMinimumSize {
+ return
+ }
+ vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize)
+ switch h.Code() {
+ case header.ICMPv4PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, vv)
+
+ case header.ICMPv4FragmentationNeeded:
+ mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:]))
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ }
+ }
+ // TODO: Handle other ICMP types.
+}
+
+type echoRequest struct {
+ r stack.Route
+ v buffer.View
+}
+
+func (e *endpoint) echoReplier() {
+ for req := range e.echoRequests {
+ sendICMPv4(&req.r, header.ICMPv4EchoReply, 0, req.v)
+ req.r.Release()
+ }
+}
+
+func sendICMPv4(r *stack.Route, typ header.ICMPv4Type, code byte, data buffer.View) *tcpip.Error {
+ hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpv4.SetType(typ)
+ icmpv4.SetCode(code)
+ icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+
+ return r.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber)
+}
+
+// A Pinger can send echo requests to an address.
+type Pinger struct {
+ Stack *stack.Stack
+ NICID tcpip.NICID
+ Addr tcpip.Address
+ LocalAddr tcpip.Address // optional
+ Wait time.Duration // if zero, defaults to 1 second
+ Count uint16 // if zero, defaults to MaxUint16
+}
+
+// Ping sends echo requests to an ICMPv4 endpoint.
+// Responses are streamed to the channel ch.
+func (p *Pinger) Ping(ctx context.Context, ch chan<- PingReply) *tcpip.Error {
+ count := p.Count
+ if count == 0 {
+ count = 1<<16 - 1
+ }
+ wait := p.Wait
+ if wait == 0 {
+ wait = 1 * time.Second
+ }
+
+ r, err := p.Stack.FindRoute(p.NICID, p.LocalAddr, p.Addr, ProtocolNumber)
+ if err != nil {
+ return err
+ }
+
+ netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber}
+ ep := &pingEndpoint{
+ stack: p.Stack,
+ pktCh: make(chan buffer.View, 1),
+ }
+ id := stack.TransportEndpointID{
+ LocalAddress: r.LocalAddress,
+ RemoteAddress: p.Addr,
+ }
+
+ _, err = p.Stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
+ id.LocalPort = port
+ err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id, ep)
+ switch err {
+ case nil:
+ return true, nil
+ case tcpip.ErrPortInUse:
+ return false, nil
+ default:
+ return false, err
+ }
+ })
+ if err != nil {
+ return err
+ }
+ defer p.Stack.UnregisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id)
+
+ v := buffer.NewView(4)
+ binary.BigEndian.PutUint16(v[0:], id.LocalPort)
+
+ start := time.Now()
+
+ done := make(chan struct{})
+ go func(count int) {
+ loop:
+ for ; count > 0; count-- {
+ select {
+ case v := <-ep.pktCh:
+ seq := binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize+2:])
+ ch <- PingReply{
+ Duration: time.Since(start) - time.Duration(seq)*wait,
+ SeqNumber: seq,
+ }
+ case <-ctx.Done():
+ break loop
+ }
+ }
+ close(done)
+ }(int(count))
+ defer func() { <-done }()
+
+ t := time.NewTicker(wait)
+ defer t.Stop()
+ for seq := uint16(0); seq < count; seq++ {
+ select {
+ case <-t.C:
+ case <-ctx.Done():
+ return nil
+ }
+ binary.BigEndian.PutUint16(v[2:], seq)
+ sent := time.Now()
+ if err := sendICMPv4(&r, header.ICMPv4Echo, 0, v); err != nil {
+ ch <- PingReply{
+ Error: err,
+ Duration: time.Since(sent),
+ SeqNumber: seq,
+ }
+ }
+ }
+ return nil
+}
+
+// PingReply summarizes an ICMP echo reply.
+type PingReply struct {
+ Error *tcpip.Error // reports any errors sending a ping request
+ Duration time.Duration
+ SeqNumber uint16
+}
+
+type pingProtocol struct{}
+
+func (*pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrNotSupported // endpoints are created directly
+}
+
+func (*pingProtocol) Number() tcpip.TransportProtocolNumber { return pingProtocolNumber }
+
+func (*pingProtocol) MinimumPacketSize() int { return header.ICMPv4EchoMinimumSize }
+
+func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ ident := binary.BigEndian.Uint16(v[4:])
+ return 0, ident, nil
+}
+
+func (*pingProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
+ return true
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *pingProtocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements TransportProtocol.Option.
+func (p *pingProtocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(PingProtocolName, func() stack.TransportProtocol {
+ return &pingProtocol{}
+ })
+}
+
+type pingEndpoint struct {
+ stack *stack.Stack
+ pktCh chan buffer.View
+}
+
+func (e *pingEndpoint) Close() {
+ close(e.pktCh)
+}
+
+func (e *pingEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
+ select {
+ case e.pktCh <- vv.ToView():
+ default:
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *pingEndpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
+}
diff --git a/pkg/tcpip/network/ipv4/icmp_test.go b/pkg/tcpip/network/ipv4/icmp_test.go
new file mode 100644
index 000000000..378fba74b
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/icmp_test.go
@@ -0,0 +1,124 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ipv4_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const stackAddr = "\x0a\x00\x00\x01"
+
+type testContext struct {
+ t *testing.T
+ linkEP *channel.Endpoint
+ s *stack.Stack
+}
+
+func newTestContext(t *testing.T) *testContext {
+ s := stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName})
+
+ const defaultMTU = 65536
+ id, linkEP := channel.New(256, defaultMTU, "")
+ if testing.Verbose() {
+ id = sniffer.New(id)
+ }
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: "\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ }})
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEP: linkEP,
+ }
+}
+
+func (c *testContext) cleanup() {
+ close(c.linkEP.C)
+}
+
+func (c *testContext) loopback() {
+ go func() {
+ for pkt := range c.linkEP.C {
+ v := make(buffer.View, len(pkt.Header)+len(pkt.Payload))
+ copy(v, pkt.Header)
+ copy(v[len(pkt.Header):], pkt.Payload)
+ vv := v.ToVectorisedView([1]buffer.View{})
+ c.linkEP.Inject(pkt.Proto, &vv)
+ }
+ }()
+}
+
+func TestEcho(t *testing.T) {
+ c := newTestContext(t)
+ defer c.cleanup()
+ c.loopback()
+
+ ch := make(chan ipv4.PingReply, 1)
+ p := ipv4.Pinger{
+ Stack: c.s,
+ NICID: 1,
+ Addr: stackAddr,
+ Wait: 10 * time.Millisecond,
+ Count: 1, // one ping only
+ }
+ if err := p.Ping(context.Background(), ch); err != nil {
+ t.Fatalf("icmp.Ping failed: %v", err)
+ }
+
+ ping := <-ch
+ if ping.Error != nil {
+ t.Errorf("bad ping response: %v", ping.Error)
+ }
+}
+
+func TestEchoSequence(t *testing.T) {
+ c := newTestContext(t)
+ defer c.cleanup()
+ c.loopback()
+
+ const numPings = 3
+ ch := make(chan ipv4.PingReply, numPings)
+ p := ipv4.Pinger{
+ Stack: c.s,
+ NICID: 1,
+ Addr: stackAddr,
+ Wait: 10 * time.Millisecond,
+ Count: numPings,
+ }
+ if err := p.Ping(context.Background(), ch); err != nil {
+ t.Fatalf("icmp.Ping failed: %v", err)
+ }
+
+ for i := uint16(0); i < numPings; i++ {
+ ping := <-ch
+ if ping.Error != nil {
+ t.Errorf("i=%d bad ping response: %v", i, ping.Error)
+ }
+ if ping.SeqNumber != i {
+ t.Errorf("SeqNumber=%d, want %d", ping.SeqNumber, i)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
new file mode 100644
index 000000000..4cc2a2fd4
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -0,0 +1,233 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package ipv4 contains the implementation of the ipv4 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
+// network protocols when calling stack.New(). Then endpoints can be created
+// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv4
+
+import (
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/hash"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolName is the string representation of the ipv4 protocol name.
+ ProtocolName = "ipv4"
+
+ // ProtocolNumber is the ipv4 protocol number.
+ ProtocolNumber = header.IPv4ProtocolNumber
+
+ // maxTotalSize is maximum size that can be encoded in the 16-bit
+ // TotalLength field of the ipv4 header.
+ maxTotalSize = 0xffff
+
+ // buckets is the number of identifier buckets.
+ buckets = 2048
+)
+
+type address [header.IPv4AddressSize]byte
+
+type endpoint struct {
+ nicid tcpip.NICID
+ id stack.NetworkEndpointID
+ address address
+ linkEP stack.LinkEndpoint
+ dispatcher stack.TransportDispatcher
+ echoRequests chan echoRequest
+ fragmentation *fragmentation.Fragmentation
+}
+
+func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint {
+ e := &endpoint{
+ nicid: nicid,
+ linkEP: linkEP,
+ dispatcher: dispatcher,
+ echoRequests: make(chan echoRequest, 10),
+ fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ }
+ copy(e.address[:], addr)
+ e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}
+
+ go e.echoReplier()
+
+ return e
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// NICID returns the ID of the NIC this endpoint belongs to.
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicid
+}
+
+// ID returns the ipv4 endpoint ID.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &e.id
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ length := uint16(hdr.UsedLength() + len(payload))
+ id := uint32(0)
+ if length > header.IPv4MaximumHeaderSize+8 {
+ // Packets of 68 bytes or less are required by RFC 791 to not be
+ // fragmented, so we only assign ids to larger packets.
+ id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: length,
+ ID: uint16(id),
+ TTL: 65,
+ Protocol: uint8(protocol),
+ SrcAddr: tcpip.Address(e.address[:]),
+ DstAddr: r.RemoteAddress,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
+}
+
+// HandlePacket is called by the link layer when new ipv4 packets arrive for
+// this endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
+ h := header.IPv4(vv.First())
+ if !h.IsValid(vv.Size()) {
+ return
+ }
+
+ hlen := int(h.HeaderLength())
+ tlen := int(h.TotalLength())
+ vv.TrimFront(hlen)
+ vv.CapLength(tlen - hlen)
+
+ more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
+ if more || h.FragmentOffset() != 0 {
+ // The packet is a fragment, let's try to reassemble it.
+ last := h.FragmentOffset() + uint16(vv.Size()) - 1
+ tt, ready := e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ if !ready {
+ return
+ }
+ vv = &tt
+ }
+ p := h.TransportProtocol()
+ if p == header.ICMPv4ProtocolNumber {
+ e.handleICMP(r, vv)
+ return
+ }
+ e.dispatcher.DeliverTransportPacket(r, p, vv)
+}
+
+// Close cleans up resources associated with the endpoint.
+func (e *endpoint) Close() {
+ close(e.echoRequests)
+}
+
+type protocol struct{}
+
+// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
+// only for tests that short-circuit the stack. Regular use of the protocol is
+// done via the stack, which gets a protocol descriptor from the init() function
+// below.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
+}
+
+// Number returns the ipv4 protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
+// MinimumPacketSize returns the minimum valid ipv4 packet size.
+func (p *protocol) MinimumPacketSize() int {
+ return header.IPv4MinimumSize
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv4(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint creates a new ipv4 endpoint.
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ return newEndpoint(nicid, addr, dispatcher, linkEP), nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload mtu.
+func calculateMTU(mtu uint32) uint32 {
+ if mtu > maxTotalSize {
+ mtu = maxTotalSize
+ }
+ return mtu - header.IPv4MinimumSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source &
+// destination address, the transport protocol number, and a random initial
+// value (generated once on initialization) to generate the hash.
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
+ t := r.LocalAddress
+ a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = r.RemoteAddress
+ b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return hash.Hash3Words(a, b, uint32(protocol), hashIV)
+}
+
+var (
+ ids []uint32
+ hashIV uint32
+)
+
+func init() {
+ ids = make([]uint32, buckets)
+
+ // Randomly initialize hashIV and the ids.
+ r := hash.RandN32(1 + buckets)
+ for i := range ids {
+ ids[i] = r[i]
+ }
+ hashIV = r[buckets]
+
+ stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
new file mode 100644
index 000000000..db7da0af3
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -0,0 +1,21 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "ipv6",
+ srcs = [
+ "icmp.go",
+ "ipv6.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
new file mode 100644
index 000000000..0fc6dcce2
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -0,0 +1,80 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ipv6
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
+ h := header.IPv6(vv.First())
+
+ // We don't use IsValid() here because ICMP only requires that up to
+ // 1280 bytes of the original packet be included. So it's likely that it
+ // is truncated, which would cause IsValid to return false.
+ //
+ // Drop packet if it doesn't have the basic IPv6 header or if the
+ // original source address doesn't match the endpoint's address.
+ if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ // Skip the IP header, then handle the fragmentation header if there
+ // is one.
+ vv.TrimFront(header.IPv6MinimumSize)
+ p := h.TransportProtocol()
+ if p == header.IPv6FragmentHeader {
+ f := header.IPv6Fragment(vv.First())
+ if !f.IsValid() || f.FragmentOffset() != 0 {
+ // We can't handle fragments that aren't at offset 0
+ // because they don't have the transport headers.
+ return
+ }
+
+ // Skip fragmentation header and find out the actual protocol
+ // number.
+ vv.TrimFront(header.IPv6FragmentHeaderSize)
+ p = f.TransportProtocol()
+ }
+
+ // Deliver the control packet to the transport endpoint.
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
+ v := vv.First()
+ if len(v) < header.ICMPv6MinimumSize {
+ return
+ }
+ h := header.ICMPv6(v)
+
+ switch h.Type() {
+ case header.ICMPv6PacketTooBig:
+ if len(v) < header.ICMPv6PacketTooBigMinimumSize {
+ return
+ }
+ vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
+ mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:])
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+
+ case header.ICMPv6DstUnreachable:
+ if len(v) < header.ICMPv6DstUnreachableMinimumSize {
+ return
+ }
+ vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch h.Code() {
+ case header.ICMPv6PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, vv)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
new file mode 100644
index 000000000..15654cbbd
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -0,0 +1,172 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package ipv6 contains the implementation of the ipv6 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
+// network protocols when calling stack.New(). Then endpoints can be created
+// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv6
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolName is the string representation of the ipv6 protocol name.
+ ProtocolName = "ipv6"
+
+ // ProtocolNumber is the ipv6 protocol number.
+ ProtocolNumber = header.IPv6ProtocolNumber
+
+ // maxTotalSize is maximum size that can be encoded in the 16-bit
+ // PayloadLength field of the ipv6 header.
+ maxPayloadSize = 0xffff
+)
+
+type address [header.IPv6AddressSize]byte
+
+type endpoint struct {
+ nicid tcpip.NICID
+ id stack.NetworkEndpointID
+ address address
+ linkEP stack.LinkEndpoint
+ dispatcher stack.TransportDispatcher
+}
+
+func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint {
+ e := &endpoint{nicid: nicid, linkEP: linkEP, dispatcher: dispatcher}
+ copy(e.address[:], addr)
+ e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}
+ return e
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// NICID returns the ID of the NIC this endpoint belongs to.
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicid
+}
+
+// ID returns the ipv6 endpoint ID.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &e.id
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+ length := uint16(hdr.UsedLength())
+ if payload != nil {
+ length += uint16(len(payload))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(protocol),
+ HopLimit: 65,
+ SrcAddr: tcpip.Address(e.address[:]),
+ DstAddr: r.RemoteAddress,
+ })
+
+ return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
+}
+
+// HandlePacket is called by the link layer when new ipv6 packets arrive for
+// this endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
+ h := header.IPv6(vv.First())
+ if !h.IsValid(vv.Size()) {
+ return
+ }
+
+ vv.TrimFront(header.IPv6MinimumSize)
+ vv.CapLength(int(h.PayloadLength()))
+
+ p := h.TransportProtocol()
+ if p == header.ICMPv6ProtocolNumber {
+ e.handleICMP(r, vv)
+ return
+ }
+
+ e.dispatcher.DeliverTransportPacket(r, p, vv)
+}
+
+// Close cleans up resources associated with the endpoint.
+func (*endpoint) Close() {}
+
+type protocol struct{}
+
+// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
+// only for tests that short-circuit the stack. Regular use of the protocol is
+// done via the stack, which gets a protocol descriptor from the init() function
+// below.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
+}
+
+// Number returns the ipv6 protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
+// MinimumPacketSize returns the minimum valid ipv6 packet size.
+func (p *protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint creates a new ipv6 endpoint.
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ return newEndpoint(nicid, addr, dispatcher, linkEP), nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload mtu.
+func calculateMTU(mtu uint32) uint32 {
+ mtu -= header.IPv6MinimumSize
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
+func init() {
+ stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
+ return &protocol{}
+ })
+}