diff options
author | Googler <noreply@google.com> | 2018-04-27 10:37:02 -0700 |
---|---|---|
committer | Adin Scannell <ascannell@google.com> | 2018-04-28 01:44:26 -0400 |
commit | d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch) | |
tree | 54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/tcpip/network | |
parent | f70210e742919f40aa2f0934a22f1c9ba6dada62 (diff) |
Check in gVisor.
PiperOrigin-RevId: 194583126
Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/tcpip/network')
21 files changed, 2699 insertions, 0 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD new file mode 100644 index 000000000..36ddaa692 --- /dev/null +++ b/pkg/tcpip/network/BUILD @@ -0,0 +1,19 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_test") + +go_test( + name = "ip_test", + size = "small", + srcs = [ + "ip_test.go", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD new file mode 100644 index 000000000..e6d0899a9 --- /dev/null +++ b/pkg/tcpip/network/arp/BUILD @@ -0,0 +1,34 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "arp", + srcs = ["arp.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/arp", + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "arp_test", + size = "small", + srcs = ["arp_test.go"], + deps = [ + ":arp", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go new file mode 100644 index 000000000..4e3d7f597 --- /dev/null +++ b/pkg/tcpip/network/arp/arp.go @@ -0,0 +1,170 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package arp implements the ARP network protocol. It is used to resolve +// IPv4 addresses into link-local MAC addresses, and advertises IPv4 +// addresses of its stack with the local network. +// +// To use it in the networking stack, pass arp.ProtocolName as one of the +// network protocols when calling stack.New. Then add an "arp" address to +// every NIC on the stack that should respond to ARP requests. That is: +// +// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { +// // handle err +// } +package arp + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +const ( + // ProtocolName is the string representation of the ARP protocol name. + ProtocolName = "arp" + + // ProtocolNumber is the ARP protocol number. + ProtocolNumber = header.ARPProtocolNumber + + // ProtocolAddress is the address expected by the ARP endpoint. + ProtocolAddress = tcpip.Address("arp") +) + +// endpoint implements stack.NetworkEndpoint. +type endpoint struct { + nicid tcpip.NICID + addr tcpip.Address + linkEP stack.LinkEndpoint + linkAddrCache stack.LinkAddressCache +} + +func (e *endpoint) MTU() uint32 { + lmtu := e.linkEP.MTU() + return lmtu - uint32(e.MaxHeaderLength()) +} + +func (e *endpoint) NICID() tcpip.NICID { + return e.nicid +} + +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &stack.NetworkEndpointID{ProtocolAddress} +} + +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.ARPSize +} + +func (e *endpoint) Close() {} + +func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error { + return tcpip.ErrNotSupported +} + +func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { + v := vv.First() + h := header.ARP(v) + if !h.IsValid() { + return + } + + switch h.Op() { + case header.ARPRequest: + localAddr := tcpip.Address(h.ProtocolAddressTarget()) + if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 { + return // we have no useful answer, ignore the request + } + hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) + pkt := header.ARP(hdr.Prepend(header.ARPSize)) + pkt.SetIPv4OverEthernet() + pkt.SetOp(header.ARPReply) + copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) + copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) + e.linkEP.WritePacket(r, &hdr, nil, ProtocolNumber) + fallthrough // also fill the cache from requests + case header.ARPReply: + addr := tcpip.Address(h.ProtocolAddressSender()) + linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) + } +} + +// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. +type protocol struct { +} + +func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } +func (p *protocol) MinimumPacketSize() int { return header.ARPSize } + +func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.ARP(v) + return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress +} + +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + if addr != ProtocolAddress { + return nil, tcpip.ErrBadLocalAddress + } + return &endpoint{ + nicid: nicid, + addr: addr, + linkEP: sender, + linkAddrCache: linkAddrCache, + }, nil +} + +// LinkAddressProtocol implements stack.LinkAddressResolver. +func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv4ProtocolNumber +} + +// LinkAddressRequest implements stack.LinkAddressResolver. +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { + r := &stack.Route{ + RemoteLinkAddress: broadcastMAC, + } + + hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) + h := header.ARP(hdr.Prepend(header.ARPSize)) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) + copy(h.HardwareAddressSender(), linkEP.LinkAddress()) + copy(h.ProtocolAddressSender(), localAddr) + copy(h.ProtocolAddressTarget(), addr) + + return linkEP.WritePacket(r, &hdr, nil, ProtocolNumber) +} + +// ResolveStaticAddress implements stack.LinkAddressResolver. +func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\xff\xff\xff\xff" { + return broadcastMAC, true + } + return "", false +} + +// SetOption implements NetworkProtocol. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements NetworkProtocol. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) + +func init() { + stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { + return &protocol{} + }) +} diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go new file mode 100644 index 000000000..91ffdce4b --- /dev/null +++ b/pkg/tcpip/network/arp/arp_test.go @@ -0,0 +1,138 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package arp_test + +import ( + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/arp" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +const ( + stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") + stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") + stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") + stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") +) + +type testContext struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack +} + +func newTestContext(t *testing.T) *testContext { + s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{ipv4.PingProtocolName}) + + const defaultMTU = 65536 + id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) + if testing.Verbose() { + id = sniffer.New(id) + } + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } + if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("AddAddress for arp failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }}) + + return &testContext{ + t: t, + s: s, + linkEP: linkEP, + } +} + +func (c *testContext) cleanup() { + close(c.linkEP.C) +} + +func TestDirectRequest(t *testing.T) { + c := newTestContext(t) + defer c.cleanup() + + const senderMAC = "\x01\x02\x03\x04\x05\x06" + const senderIPv4 = "\x0a\x00\x00\x02" + + v := make(buffer.View, header.ARPSize) + h := header.ARP(v) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) + copy(h.HardwareAddressSender(), senderMAC) + copy(h.ProtocolAddressSender(), senderIPv4) + + // stackAddr1 + copy(h.ProtocolAddressTarget(), stackAddr1) + vv := v.ToVectorisedView([1]buffer.View{}) + c.linkEP.Inject(arp.ProtocolNumber, &vv) + pkt := <-c.linkEP.C + if pkt.Proto != arp.ProtocolNumber { + t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto) + } + rep := header.ARP(pkt.Header) + if !rep.IsValid() { + t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + } + if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 { + t.Errorf("stackAddr1: expected sender to be set") + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got) + } + + // stackAddr2 + copy(h.ProtocolAddressTarget(), stackAddr2) + vv = v.ToVectorisedView([1]buffer.View{}) + c.linkEP.Inject(arp.ProtocolNumber, &vv) + pkt = <-c.linkEP.C + if pkt.Proto != arp.ProtocolNumber { + t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto) + } + rep = header.ARP(pkt.Header) + if !rep.IsValid() { + t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + } + if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 { + t.Errorf("stackAddr2: expected sender to be set") + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got) + } + + // stackAddrBad + copy(h.ProtocolAddressTarget(), stackAddrBad) + vv = v.ToVectorisedView([1]buffer.View{}) + c.linkEP.Inject(arp.ProtocolNumber, &vv) + select { + case pkt := <-c.linkEP.C: + t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) + case <-time.After(100 * time.Millisecond): + // Sleep tests are gross, but this will only + // potentially fail flakily if there's a bugj + // If there is no bug this will reliably succeed. + } +} diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD new file mode 100644 index 000000000..78fe878ec --- /dev/null +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -0,0 +1,61 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_stateify") + +go_stateify( + name = "fragmentation_state", + srcs = ["reassembler_list.go"], + out = "fragmentation_state.go", + package = "fragmentation", +) + +go_template_instance( + name = "reassembler_list", + out = "reassembler_list.go", + package = "fragmentation", + prefix = "reassembler", + template = "//pkg/ilist:generic_list", + types = { + "Linker": "*reassembler", + }, +) + +go_library( + name = "fragmentation", + srcs = [ + "frag_heap.go", + "fragmentation.go", + "fragmentation_state.go", + "reassembler.go", + "reassembler_list.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/fragmentation", + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "//pkg/state", + "//pkg/tcpip/buffer", + ], +) + +go_test( + name = "fragmentation_test", + size = "small", + srcs = [ + "frag_heap_test.go", + "fragmentation_test.go", + "reassembler_test.go", + ], + embed = [":fragmentation"], + deps = ["//pkg/tcpip/buffer"], +) + +filegroup( + name = "autogen", + srcs = [ + "reassembler_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go new file mode 100644 index 000000000..2e8512909 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/frag_heap.go @@ -0,0 +1,67 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fragmentation + +import ( + "container/heap" + "fmt" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +type fragment struct { + offset uint16 + vv *buffer.VectorisedView +} + +type fragHeap []fragment + +func (h *fragHeap) Len() int { + return len(*h) +} + +func (h *fragHeap) Less(i, j int) bool { + return (*h)[i].offset < (*h)[j].offset +} + +func (h *fragHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] +} + +func (h *fragHeap) Push(x interface{}) { + *h = append(*h, x.(fragment)) +} + +func (h *fragHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[:n-1] + return x +} + +// reassamble empties the heap and returns a VectorisedView +// containing a reassambled version of the fragments inside the heap. +func (h *fragHeap) reassemble() (buffer.VectorisedView, error) { + curr := heap.Pop(h).(fragment) + views := curr.vv.Views() + size := curr.vv.Size() + + if curr.offset != 0 { + return buffer.NewVectorisedView(0, nil), fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset) + } + + for h.Len() > 0 { + curr := heap.Pop(h).(fragment) + if int(curr.offset) < size { + curr.vv.TrimFront(size - int(curr.offset)) + } else if int(curr.offset) > size { + return buffer.NewVectorisedView(0, nil), fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset) + } + size += curr.vv.Size() + views = append(views, curr.vv.Views()...) + } + return buffer.NewVectorisedView(size, views), nil +} diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go new file mode 100644 index 000000000..218a24d7b --- /dev/null +++ b/pkg/tcpip/network/fragmentation/frag_heap_test.go @@ -0,0 +1,112 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fragmentation + +import ( + "container/heap" + "reflect" + "testing" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +var reassambleTestCases = []struct { + comment string + in []fragment + want *buffer.VectorisedView +}{ + { + comment: "Non-overlapping in-order", + in: []fragment{ + {offset: 0, vv: vv(1, "0")}, + {offset: 1, vv: vv(1, "1")}, + }, + want: vv(2, "0", "1"), + }, + { + comment: "Non-overlapping out-of-order", + in: []fragment{ + {offset: 1, vv: vv(1, "1")}, + {offset: 0, vv: vv(1, "0")}, + }, + want: vv(2, "0", "1"), + }, + { + comment: "Duplicated packets", + in: []fragment{ + {offset: 0, vv: vv(1, "0")}, + {offset: 0, vv: vv(1, "0")}, + }, + want: vv(1, "0"), + }, + { + comment: "Overlapping in-order", + in: []fragment{ + {offset: 0, vv: vv(2, "01")}, + {offset: 1, vv: vv(2, "12")}, + }, + want: vv(3, "01", "2"), + }, + { + comment: "Overlapping out-of-order", + in: []fragment{ + {offset: 1, vv: vv(2, "12")}, + {offset: 0, vv: vv(2, "01")}, + }, + want: vv(3, "01", "2"), + }, + { + comment: "Overlapping subset in-order", + in: []fragment{ + {offset: 0, vv: vv(3, "012")}, + {offset: 1, vv: vv(1, "1")}, + }, + want: vv(3, "012"), + }, + { + comment: "Overlapping subset out-of-order", + in: []fragment{ + {offset: 1, vv: vv(1, "1")}, + {offset: 0, vv: vv(3, "012")}, + }, + want: vv(3, "012"), + }, +} + +func TestReassamble(t *testing.T) { + for _, c := range reassambleTestCases { + h := (fragHeap)(make([]fragment, 0, 8)) + heap.Init(&h) + for _, f := range c.in { + heap.Push(&h, f) + } + got, _ := h.reassemble() + + if !reflect.DeepEqual(got, *c.want) { + t.Errorf("Test \"%s\" reassembling failed. Got %v. Want %v", c.comment, got, *c.want) + } + } +} + +func TestReassambleFailsForNonZeroOffset(t *testing.T) { + h := (fragHeap)(make([]fragment, 0, 8)) + heap.Init(&h) + heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) + _, err := h.reassemble() + if err == nil { + t.Errorf("reassemble() did not fail when the first packet had offset != 0") + } +} + +func TestReassambleFailsForHoles(t *testing.T) { + h := (fragHeap)(make([]fragment, 0, 8)) + heap.Init(&h) + heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) + heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) + _, err := h.reassemble() + if err == nil { + t.Errorf("reassemble() did not fail when there was a hole in the packet") + } +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go new file mode 100644 index 000000000..a309a24c5 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -0,0 +1,124 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fragmentation contains the implementation of IP fragmentation. +// It is based on RFC 791 and RFC 815. +package fragmentation + +import ( + "log" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. +const DefaultReassembleTimeout = 30 * time.Second + +// HighFragThreshold is the threshold at which we start trimming old +// fragmented packets. Linux uses a default value of 4 MB. See +// net.ipv4.ipfrag_high_thresh for more information. +const HighFragThreshold = 4 << 20 // 4MB + +// LowFragThreshold is the threshold we reach to when we start dropping +// older fragmented packets. It's important that we keep enough room for newer +// packets to be re-assembled. Hence, this needs to be lower than +// HighFragThreshold enough. Linux uses a default value of 3 MB. See +// net.ipv4.ipfrag_low_thresh for more information. +const LowFragThreshold = 3 << 20 // 3MB + +// Fragmentation is the main structure that other modules +// of the stack should use to implement IP Fragmentation. +type Fragmentation struct { + mu sync.Mutex + highLimit int + lowLimit int + reassemblers map[uint32]*reassembler + rList reassemblerList + size int + timeout time.Duration +} + +// NewFragmentation creates a new Fragmentation. +// +// highMemoryLimit specifies the limit on the memory consumed +// by the fragments stored by Fragmentation (overhead of internal data-structures +// is not accounted). Fragments are dropped when the limit is reached. +// +// lowMemoryLimit specifies the limit on which we will reach by dropping +// fragments after reaching highMemoryLimit. +// +// reassemblingTimeout specifes the maximum time allowed to reassemble a packet. +// Fragments are lazily evicted only when a new a packet with an +// already existing fragmentation-id arrives after the timeout. +func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { + if lowMemoryLimit >= highMemoryLimit { + lowMemoryLimit = highMemoryLimit + } + + if lowMemoryLimit < 0 { + lowMemoryLimit = 0 + } + + return &Fragmentation{ + reassemblers: make(map[uint32]*reassembler), + highLimit: highMemoryLimit, + lowLimit: lowMemoryLimit, + timeout: reassemblingTimeout, + } +} + +// Process processes an incoming fragment beloning to an ID +// and returns a complete packet when all the packets belonging to that ID have been received. +func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv *buffer.VectorisedView) (buffer.VectorisedView, bool) { + f.mu.Lock() + r, ok := f.reassemblers[id] + if ok && r.tooOld(f.timeout) { + // This is very likely to be an id-collision or someone performing a slow-rate attack. + f.release(r) + ok = false + } + if !ok { + r = newReassembler(id) + f.reassemblers[id] = r + f.rList.PushFront(r) + } + f.mu.Unlock() + + res, done, consumed := r.process(first, last, more, vv) + + f.mu.Lock() + f.size += consumed + if done { + f.release(r) + } + // Evict reassemblers if we are consuming more memory than highLimit until + // we reach lowLimit. + if f.size > f.highLimit { + tail := f.rList.Back() + for f.size > f.lowLimit && tail != nil { + f.release(tail) + tail = tail.Prev() + } + } + f.mu.Unlock() + return res, done +} + +func (f *Fragmentation) release(r *reassembler) { + // Before releasing a fragment we need to check if r is already marked as done. + // Otherwise, we would delete it twice. + if r.checkDoneOrMark() { + return + } + + delete(f.reassemblers, r.id) + f.rList.Remove(r) + f.size -= r.size + if f.size < 0 { + log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) + f.size = 0 + } +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go new file mode 100644 index 000000000..2f0200d26 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -0,0 +1,166 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fragmentation + +import ( + "reflect" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +// vv is a helper to build VectorisedView from different strings. +func vv(size int, pieces ...string) *buffer.VectorisedView { + views := make([]buffer.View, len(pieces)) + for i, p := range pieces { + views[i] = []byte(p) + } + + vv := buffer.NewVectorisedView(size, views) + return &vv +} + +func emptyVv() *buffer.VectorisedView { + vv := buffer.NewVectorisedView(0, nil) + return &vv +} + +type processInput struct { + id uint32 + first uint16 + last uint16 + more bool + vv *buffer.VectorisedView +} + +type processOutput struct { + vv *buffer.VectorisedView + done bool +} + +var processTestCases = []struct { + comment string + in []processInput + out []processOutput +}{ + { + comment: "One ID", + in: []processInput{ + {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + }, + out: []processOutput{ + {vv: emptyVv(), done: false}, + {vv: vv(4, "01", "23"), done: true}, + }, + }, + { + comment: "Two IDs", + in: []processInput{ + {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")}, + {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")}, + {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + }, + out: []processOutput{ + {vv: emptyVv(), done: false}, + {vv: emptyVv(), done: false}, + {vv: vv(4, "ab", "cd"), done: true}, + {vv: vv(4, "01", "23"), done: true}, + }, + }, +} + +func TestFragmentationProcess(t *testing.T) { + for _, c := range processTestCases { + f := NewFragmentation(1024, 512, DefaultReassembleTimeout) + for i, in := range c.in { + vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv) + if !reflect.DeepEqual(vv, *(c.out[i].vv)) { + t.Errorf("Test \"%s\" Process() returned a wrong vv. Got %v. Want %v", c.comment, vv, *(c.out[i].vv)) + } + if done != c.out[i].done { + t.Errorf("Test \"%s\" Process() returned a wrong done. Got %t. Want %t", c.comment, done, c.out[i].done) + } + if c.out[i].done { + if _, ok := f.reassemblers[in.id]; ok { + t.Errorf("Test \"%s\" Process() didn't remove buffer from reassemblers.", c.comment) + } + for n := f.rList.Front(); n != nil; n = n.Next() { + if n.id == in.id { + t.Errorf("Test \"%s\" Process() didn't remove buffer from rList.", c.comment) + } + } + } + } + } +} + +func TestReassemblingTimeout(t *testing.T) { + timeout := time.Millisecond + f := NewFragmentation(1024, 512, timeout) + // Send first fragment with id = 0, first = 0, last = 0, and more = true. + f.Process(0, 0, 0, true, vv(1, "0")) + // Sleep more than the timeout. + time.Sleep(2 * timeout) + // Send another fragment that completes a packet. + // However, no packet should be reassembled because the fragment arrived after the timeout. + _, done := f.Process(0, 1, 1, false, vv(1, "1")) + if done { + t.Errorf("Fragmentation does not respect the reassembling timeout.") + } +} + +func TestMemoryLimits(t *testing.T) { + f := NewFragmentation(3, 1, DefaultReassembleTimeout) + // Send first fragment with id = 0. + f.Process(0, 0, 0, true, vv(1, "0")) + // Send first fragment with id = 1. + f.Process(1, 0, 0, true, vv(1, "1")) + // Send first fragment with id = 2. + f.Process(2, 0, 0, true, vv(1, "2")) + + // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be + // evicted. + f.Process(3, 0, 0, true, vv(1, "3")) + + if _, ok := f.reassemblers[0]; ok { + t.Errorf("Memory limits are not respected: id=0 has not been evicted.") + } + if _, ok := f.reassemblers[1]; ok { + t.Errorf("Memory limits are not respected: id=1 has not been evicted.") + } + if _, ok := f.reassemblers[3]; !ok { + t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") + } +} + +func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { + f := NewFragmentation(1, 0, DefaultReassembleTimeout) + // Send first fragment with id = 0. + f.Process(0, 0, 0, true, vv(1, "0")) + // Send the same packet again. + f.Process(0, 0, 0, true, vv(1, "0")) + + got := f.size + want := 1 + if got != want { + t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) + } +} + +func TestFragmentationViewsDoNotEscape(t *testing.T) { + f := NewFragmentation(1024, 512, DefaultReassembleTimeout) + in := vv(2, "0", "1") + f.Process(0, 0, 1, true, in) + // Modify input view. + in.RemoveFirst() + got, _ := f.Process(0, 2, 2, false, vv(1, "2")) + want := vv(3, "0", "1", "2") + if !reflect.DeepEqual(got, *want) { + t.Errorf("Process() returned a wrong vv. Got %v. Want %v", got, *want) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go new file mode 100644 index 000000000..0267a575d --- /dev/null +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -0,0 +1,109 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fragmentation + +import ( + "container/heap" + "fmt" + "math" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +type hole struct { + first uint16 + last uint16 + deleted bool +} + +type reassembler struct { + reassemblerEntry + id uint32 + size int + mu sync.Mutex + holes []hole + deleted int + heap fragHeap + done bool + creationTime time.Time +} + +func newReassembler(id uint32) *reassembler { + r := &reassembler{ + id: id, + holes: make([]hole, 0, 16), + deleted: 0, + heap: make(fragHeap, 0, 8), + creationTime: time.Now(), + } + r.holes = append(r.holes, hole{ + first: 0, + last: math.MaxUint16, + deleted: false}) + return r +} + +// updateHoles updates the list of holes for an incoming fragment and +// returns true iff the fragment filled at least part of an existing hole. +func (r *reassembler) updateHoles(first, last uint16, more bool) bool { + used := false + for i := range r.holes { + if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first { + continue + } + used = true + r.deleted++ + r.holes[i].deleted = true + if first > r.holes[i].first { + r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false}) + } + if last < r.holes[i].last && more { + r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false}) + } + } + return used +} + +func (r *reassembler) process(first, last uint16, more bool, vv *buffer.VectorisedView) (buffer.VectorisedView, bool, int) { + r.mu.Lock() + defer r.mu.Unlock() + consumed := 0 + if r.done { + // A concurrent goroutine might have already reassembled + // the packet and emptied the heap while this goroutine + // was waiting on the mutex. We don't have to do anything in this case. + return buffer.NewVectorisedView(0, nil), false, consumed + } + if r.updateHoles(first, last, more) { + // We store the incoming packet only if it filled some holes. + uu := vv.Clone(nil) + heap.Push(&r.heap, fragment{offset: first, vv: &uu}) + consumed = vv.Size() + r.size += consumed + } + // Check if all the holes have been deleted and we are ready to reassamble. + if r.deleted < len(r.holes) { + return buffer.NewVectorisedView(0, nil), false, consumed + } + res, err := r.heap.reassemble() + if err != nil { + panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err)) + } + return res, true, consumed +} + +func (r *reassembler) tooOld(timeout time.Duration) bool { + return time.Now().Sub(r.creationTime) > timeout +} + +func (r *reassembler) checkDoneOrMark() bool { + r.mu.Lock() + prev := r.done + r.done = true + r.mu.Unlock() + return prev +} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go new file mode 100644 index 000000000..b64604383 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -0,0 +1,95 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fragmentation + +import ( + "math" + "reflect" + "testing" +) + +type updateHolesInput struct { + first uint16 + last uint16 + more bool +} + +var holesTestCases = []struct { + comment string + in []updateHolesInput + want []hole +}{ + { + comment: "No fragments. Expected holes: {[0 -> inf]}.", + in: []updateHolesInput{}, + want: []hole{{first: 0, last: math.MaxUint16, deleted: false}}, + }, + { + comment: "One fragment at beginning. Expected holes: {[2, inf]}.", + in: []updateHolesInput{{first: 0, last: 1, more: true}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 2, last: math.MaxUint16, deleted: false}, + }, + }, + { + comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.", + in: []updateHolesInput{{first: 1, last: 2, more: true}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 0, last: 0, deleted: false}, + {first: 3, last: math.MaxUint16, deleted: false}, + }, + }, + { + comment: "One fragment at the end. Expected holes: {[0, 0]}.", + in: []updateHolesInput{{first: 1, last: 2, more: false}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 0, last: 0, deleted: false}, + }, + }, + { + comment: "One fragment completing a packet. Expected holes: {}.", + in: []updateHolesInput{{first: 0, last: 1, more: false}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + }, + }, + { + comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.", + in: []updateHolesInput{ + {first: 0, last: 1, more: true}, + {first: 2, last: 3, more: false}, + }, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 2, last: math.MaxUint16, deleted: true}, + }, + }, + { + comment: "Two overlapping fragments completing a packet. Expected holes: {}.", + in: []updateHolesInput{ + {first: 0, last: 2, more: true}, + {first: 2, last: 3, more: false}, + }, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 3, last: math.MaxUint16, deleted: true}, + }, + }, +} + +func TestUpdateHoles(t *testing.T) { + for _, c := range holesTestCases { + r := newReassembler(0) + for _, i := range c.in { + r.updateHoles(i.first, i.last, i.more) + } + if !reflect.DeepEqual(r.holes, c.want) { + t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want) + } + } +} diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD new file mode 100644 index 000000000..96805c690 --- /dev/null +++ b/pkg/tcpip/network/hash/BUILD @@ -0,0 +1,11 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "hash", + srcs = ["hash.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/hash", + visibility = ["//visibility:public"], + deps = ["//pkg/tcpip/header"], +) diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go new file mode 100644 index 000000000..e5a696158 --- /dev/null +++ b/pkg/tcpip/network/hash/hash.go @@ -0,0 +1,83 @@ +// Copyright 2017 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package hash contains utility functions for hashing. +package hash + +import ( + "crypto/rand" + "encoding/binary" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" +) + +var hashIV = RandN32(1)[0] + +// RandN32 generates a slice of n cryptographic random 32-bit numbers. +func RandN32(n int) []uint32 { + b := make([]byte, 4*n) + if _, err := rand.Read(b); err != nil { + panic("unable to get random numbers: " + err.Error()) + } + r := make([]uint32, n) + for i := range r { + r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)]) + } + return r +} + +// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted +// from linux. +func Hash3Words(a, b, c, initval uint32) uint32 { + const iv = 0xdeadbeef + (3 << 2) + initval += iv + + a += initval + b += initval + c += initval + + c ^= b + c -= rol32(b, 14) + a ^= c + a -= rol32(c, 11) + b ^= a + b -= rol32(a, 25) + c ^= b + c -= rol32(b, 16) + a ^= c + a -= rol32(c, 4) + b ^= a + b -= rol32(a, 14) + c ^= b + c -= rol32(b, 24) + + return c +} + +// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791. +func IPv4FragmentHash(h header.IPv4) uint32 { + x := uint32(h.ID())<<16 | uint32(h.Protocol()) + t := h.SourceAddress() + y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + t = h.DestinationAddress() + z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + return Hash3Words(x, y, z, hashIV) +} + +// IPv6FragmentHash computes the hash of the ipv6 fragment. +// Unlike IPv4, the protocol is not used to compute the hash. +// RFC 2640 (sec 4.5) is not very sharp on this aspect. +// As a reference, also Linux ignores the protocol to compute +// the hash (inet6_hash_frag). +func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 { + t := h.SourceAddress() + y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + t = h.DestinationAddress() + z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + return Hash3Words(f.ID(), y, z, hashIV) +} + +func rol32(v, shift uint32) uint32 { + return (v << shift) | (v >> ((-shift) & 31)) +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go new file mode 100644 index 000000000..797501858 --- /dev/null +++ b/pkg/tcpip/network/ip_test.go @@ -0,0 +1,560 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ip_test + +import ( + "testing" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// testObject implements two interfaces: LinkEndpoint and TransportDispatcher. +// The former is used to pretend that it's a link endpoint so that we can +// inspect packets written by the network endpoints. The latter is used to +// pretend that it's the network stack so that it can inspect incoming packets +// that have been handled by the network endpoints. +// +// Packets are checked by comparing their fields/values against the expected +// values stored in the test object itself. +type testObject struct { + t *testing.T + protocol tcpip.TransportProtocolNumber + contents []byte + srcAddr tcpip.Address + dstAddr tcpip.Address + v4 bool + typ stack.ControlType + extra uint32 + + dataCalls int + controlCalls int +} + +// checkValues verifies that the transport protocol, data contents, src & dst +// addresses of a packet match what's expected. If any field doesn't match, the +// test fails. +func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) { + v := vv.ToView() + if protocol != t.protocol { + t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) + } + + if srcAddr != t.srcAddr { + t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) + } + + if dstAddr != t.dstAddr { + t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) + } + + if len(v) != len(t.contents) { + t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) + } + + for i := range t.contents { + if t.contents[i] != v[i] { + t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) + } + } +} + +// DeliverTransportPacket is called by network endpoints after parsing incoming +// packets. This is used by the test object to verify that the results of the +// parsing are expected. +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) { + t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) + t.dataCalls++ +} + +// DeliverTransportControlPacket is called by network endpoints after parsing +// incoming control (ICMP) packets. This is used by the test object to verify +// that the results of the parsing are expected. +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { + t.checkValues(trans, vv, remote, local) + if typ != t.typ { + t.t.Errorf("typ = %v, want %v", typ, t.typ) + } + if extra != t.extra { + t.t.Errorf("extra = %v, want %v", extra, t.extra) + } + t.controlCalls++ +} + +// Attach is only implemented to satisfy the LinkEndpoint interface. +func (*testObject) Attach(stack.NetworkDispatcher) {} + +// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that +// matches the linux loopback MTU. +func (*testObject) MTU() uint32 { + return 65536 +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (*testObject) Capabilities() stack.LinkEndpointCapabilities { + return 0 +} + +// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. +func (*testObject) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*testObject) LinkAddress() tcpip.LinkAddress { + return "" +} + +// WritePacket is called by network endpoints after producing a packet and +// writing it to the link endpoint. This is used by the test object to verify +// that the produced packet is as expected. +func (t *testObject) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + var prot tcpip.TransportProtocolNumber + var srcAddr tcpip.Address + var dstAddr tcpip.Address + + if t.v4 { + h := header.IPv4(hdr.UsedBytes()) + prot = tcpip.TransportProtocolNumber(h.Protocol()) + srcAddr = h.SourceAddress() + dstAddr = h.DestinationAddress() + + } else { + h := header.IPv6(hdr.UsedBytes()) + prot = tcpip.TransportProtocolNumber(h.NextHeader()) + srcAddr = h.SourceAddress() + dstAddr = h.DestinationAddress() + } + var views [1]buffer.View + vv := payload.ToVectorisedView(views) + t.checkValues(prot, &vv, srcAddr, dstAddr) + return nil +} + +func TestIPv4Send(t *testing.T) { + o := testObject{t: t, v4: true} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, nil, &o) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Allocate and initialize the payload view. + payload := buffer.NewView(100) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) + } + + // Allocate the header buffer. + hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + + // Issue the write. + o.protocol = 123 + o.srcAddr = "\x0a\x00\x00\x01" + o.dstAddr = "\x0a\x00\x00\x02" + o.contents = payload + + r := stack.Route{ + RemoteAddress: o.dstAddr, + LocalAddress: o.srcAddr, + } + if err := ep.WritePacket(&r, &hdr, payload, 123); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } +} + +func TestIPv4Receive(t *testing.T) { + o := testObject{t: t, v4: true} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + totalLen := header.IPv4MinimumSize + 30 + view := buffer.NewView(totalLen) + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + TTL: 20, + Protocol: 10, + SrcAddr: "\x0a\x00\x00\x02", + DstAddr: "\x0a\x00\x00\x01", + }) + + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < totalLen; i++ { + view[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. + o.protocol = 10 + o.srcAddr = "\x0a\x00\x00\x02" + o.dstAddr = "\x0a\x00\x00\x01" + o.contents = view[header.IPv4MinimumSize:totalLen] + + r := stack.Route{ + LocalAddress: o.dstAddr, + RemoteAddress: o.srcAddr, + } + var views [1]buffer.View + vv := view.ToVectorisedView(views) + ep.HandlePacket(&r, &vv) + if o.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + } +} + +func TestIPv4ReceiveControl(t *testing.T) { + const mtu = 0xbeef - header.IPv4MinimumSize + cases := []struct { + name string + expectedCount int + fragmentOffset uint16 + code uint8 + expectedTyp stack.ControlType + expectedExtra uint32 + trunc int + }{ + {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0}, + {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, + {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, + {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, + {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8}, + {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8}, + } + r := stack.Route{ + LocalAddress: "\x0a\x00\x00\x01", + RemoteAddress: "\x0a\x00\x00\xbb", + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var views [1]buffer.View + o := testObject{t: t} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4 + view := buffer.NewView(dataOffset + 8) + + // Create the outer IPv4 header. + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(view) - c.trunc), + TTL: 20, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: "\x0a\x00\x00\xbb", + DstAddr: "\x0a\x00\x00\x01", + }) + + // Create the ICMP header. + icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) + icmp.SetType(header.ICMPv4DstUnreachable) + icmp.SetCode(c.code) + copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + + // Create the inner IPv4 header. + ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:]) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: 100, + TTL: 20, + Protocol: 10, + FragmentOffset: c.fragmentOffset, + SrcAddr: "\x0a\x00\x00\x01", + DstAddr: "\x0a\x00\x00\x02", + }) + + // Make payload be non-zero. + for i := dataOffset; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to IPv4 endpoint, dispatcher will validate that + // it's ok. + o.protocol = 10 + o.srcAddr = "\x0a\x00\x00\x02" + o.dstAddr = "\x0a\x00\x00\x01" + o.contents = view[dataOffset:] + o.typ = c.expectedTyp + o.extra = c.expectedExtra + + vv := view.ToVectorisedView(views) + vv.CapLength(len(view) - c.trunc) + ep.HandlePacket(&r, &vv) + if want := c.expectedCount; o.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + } + }) + } +} + +func TestIPv4FragmentationReceive(t *testing.T) { + o := testObject{t: t, v4: true} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + totalLen := header.IPv4MinimumSize + 24 + + frag1 := buffer.NewView(totalLen) + ip1 := header.IPv4(frag1) + ip1.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + TTL: 20, + Protocol: 10, + FragmentOffset: 0, + Flags: header.IPv4FlagMoreFragments, + SrcAddr: "\x0a\x00\x00\x02", + DstAddr: "\x0a\x00\x00\x01", + }) + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < totalLen; i++ { + frag1[i] = uint8(i) + } + + frag2 := buffer.NewView(totalLen) + ip2 := header.IPv4(frag2) + ip2.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + TTL: 20, + Protocol: 10, + FragmentOffset: 24, + SrcAddr: "\x0a\x00\x00\x02", + DstAddr: "\x0a\x00\x00\x01", + }) + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < totalLen; i++ { + frag2[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. + o.protocol = 10 + o.srcAddr = "\x0a\x00\x00\x02" + o.dstAddr = "\x0a\x00\x00\x01" + o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + + r := stack.Route{ + LocalAddress: o.dstAddr, + RemoteAddress: o.srcAddr, + } + + // Send first segment. + var views1 [1]buffer.View + vv1 := frag1.ToVectorisedView(views1) + ep.HandlePacket(&r, &vv1) + if o.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) + } + + // Send second segment. + var views2 [1]buffer.View + vv2 := frag2.ToVectorisedView(views2) + ep.HandlePacket(&r, &vv2) + + if o.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + } +} + +func TestIPv6Send(t *testing.T) { + o := testObject{t: t} + proto := ipv6.NewProtocol() + ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, nil, &o) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Allocate and initialize the payload view. + payload := buffer.NewView(100) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) + } + + // Allocate the header buffer. + hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + + // Issue the write. + o.protocol = 123 + o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + o.contents = payload + + r := stack.Route{ + RemoteAddress: o.dstAddr, + LocalAddress: o.srcAddr, + } + if err := ep.WritePacket(&r, &hdr, payload, 123); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } +} + +func TestIPv6Receive(t *testing.T) { + o := testObject{t: t} + proto := ipv6.NewProtocol() + ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, &o, nil) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + totalLen := header.IPv6MinimumSize + 30 + view := buffer.NewView(totalLen) + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(totalLen - header.IPv6MinimumSize), + NextHeader: 10, + HopLimit: 20, + SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02", + DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + }) + + // Make payload be non-zero. + for i := header.IPv6MinimumSize; i < totalLen; i++ { + view[i] = uint8(i) + } + + // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. + o.protocol = 10 + o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + o.contents = view[header.IPv6MinimumSize:totalLen] + + r := stack.Route{ + LocalAddress: o.dstAddr, + RemoteAddress: o.srcAddr, + } + var views [1]buffer.View + vv := view.ToVectorisedView(views) + ep.HandlePacket(&r, &vv) + + if o.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + } +} + +func TestIPv6ReceiveControl(t *testing.T) { + newUint16 := func(v uint16) *uint16 { return &v } + + const mtu = 0xffff + cases := []struct { + name string + expectedCount int + fragmentOffset *uint16 + typ header.ICMPv6Type + code uint8 + expectedTyp stack.ControlType + expectedExtra uint32 + trunc int + }{ + {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0}, + {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10}, + {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8}, + {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8}, + {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8}, + {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8}, + {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, + } + r := stack.Route{ + LocalAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + RemoteAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var views [1]buffer.View + o := testObject{t: t} + proto := ipv6.NewProtocol() + ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, &o, nil) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + defer ep.Close() + + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4 + if c.fragmentOffset != nil { + dataOffset += header.IPv6FragmentHeaderSize + } + view := buffer.NewView(dataOffset + 8) + + // Create the outer IPv6 header. + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 20, + SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + }) + + // Create the ICMP header. + icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) + icmp.SetType(c.typ) + icmp.SetCode(c.code) + copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + + // Create the inner IPv6 header. + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + ip.Encode(&header.IPv6Fields{ + PayloadLength: 100, + NextHeader: 10, + HopLimit: 20, + SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02", + }) + + // Build the fragmentation header if needed. + if c.fragmentOffset != nil { + ip.SetNextHeader(header.IPv6FragmentHeader) + frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + frag.Encode(&header.IPv6FragmentFields{ + NextHeader: 10, + FragmentOffset: *c.fragmentOffset, + M: true, + Identification: 0x12345678, + }) + } + + // Make payload be non-zero. + for i := dataOffset; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to IPv6 endpoint, dispatcher will validate that + // it's ok. + o.protocol = 10 + o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + o.contents = view[dataOffset:] + o.typ = c.expectedTyp + o.extra = c.expectedExtra + + vv := view.ToVectorisedView(views) + vv.CapLength(len(view) - c.trunc) + ep.HandlePacket(&r, &vv) + if want := c.expectedCount; o.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD new file mode 100644 index 000000000..9df113df1 --- /dev/null +++ b/pkg/tcpip/network/ipv4/BUILD @@ -0,0 +1,38 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "ipv4", + srcs = [ + "icmp.go", + "ipv4.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4", + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/network/fragmentation", + "//pkg/tcpip/network/hash", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) + +go_test( + name = "ipv4_test", + size = "small", + srcs = ["icmp_test.go"], + deps = [ + ":ipv4", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go new file mode 100644 index 000000000..ffd761350 --- /dev/null +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -0,0 +1,282 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipv4 + +import ( + "context" + "encoding/binary" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// PingProtocolName is a pseudo transport protocol used to handle ping replies. +// Use it when constructing a stack that intends to use ipv4.Ping. +const PingProtocolName = "icmpv4ping" + +// pingProtocolNumber is a fake transport protocol used to +// deliver incoming ICMP echo replies. The ICMP identifier +// number is used as a port number for multiplexing. +const pingProtocolNumber tcpip.TransportProtocolNumber = 256 + 11 + +// handleControl handles the case when an ICMP packet contains the headers of +// the original packet that caused the ICMP one to be sent. This information is +// used to find out which transport endpoint must be notified about the ICMP +// packet. +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { + h := header.IPv4(vv.First()) + + // We don't use IsValid() here because ICMP only requires that the IP + // header plus 8 bytes of the transport header be included. So it's + // likely that it is truncated, which would cause IsValid to return + // false. + // + // Drop packet if it doesn't have the basic IPv4 header or if the + // original source address doesn't match the endpoint's address. + if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + return + } + + hlen := int(h.HeaderLength()) + if vv.Size() < hlen || h.FragmentOffset() != 0 { + // We won't be able to handle this if it doesn't contain the + // full IPv4 header, or if it's a fragment not at offset 0 + // (because it won't have the transport header). + return + } + + // Skip the ip header, then deliver control message. + vv.TrimFront(hlen) + p := h.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) +} + +func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) { + v := vv.First() + if len(v) < header.ICMPv4MinimumSize { + return + } + h := header.ICMPv4(v) + + switch h.Type() { + case header.ICMPv4Echo: + if len(v) < header.ICMPv4EchoMinimumSize { + return + } + vv.TrimFront(header.ICMPv4MinimumSize) + req := echoRequest{r: r.Clone(), v: vv.ToView()} + select { + case e.echoRequests <- req: + default: + req.r.Release() + } + + case header.ICMPv4EchoReply: + e.dispatcher.DeliverTransportPacket(r, pingProtocolNumber, vv) + + case header.ICMPv4DstUnreachable: + if len(v) < header.ICMPv4DstUnreachableMinimumSize { + return + } + vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize) + switch h.Code() { + case header.ICMPv4PortUnreachable: + e.handleControl(stack.ControlPortUnreachable, 0, vv) + + case header.ICMPv4FragmentationNeeded: + mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:])) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + } + } + // TODO: Handle other ICMP types. +} + +type echoRequest struct { + r stack.Route + v buffer.View +} + +func (e *endpoint) echoReplier() { + for req := range e.echoRequests { + sendICMPv4(&req.r, header.ICMPv4EchoReply, 0, req.v) + req.r.Release() + } +} + +func sendICMPv4(r *stack.Route, typ header.ICMPv4Type, code byte, data buffer.View) *tcpip.Error { + hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) + + icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpv4.SetType(typ) + icmpv4.SetCode(code) + icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + + return r.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber) +} + +// A Pinger can send echo requests to an address. +type Pinger struct { + Stack *stack.Stack + NICID tcpip.NICID + Addr tcpip.Address + LocalAddr tcpip.Address // optional + Wait time.Duration // if zero, defaults to 1 second + Count uint16 // if zero, defaults to MaxUint16 +} + +// Ping sends echo requests to an ICMPv4 endpoint. +// Responses are streamed to the channel ch. +func (p *Pinger) Ping(ctx context.Context, ch chan<- PingReply) *tcpip.Error { + count := p.Count + if count == 0 { + count = 1<<16 - 1 + } + wait := p.Wait + if wait == 0 { + wait = 1 * time.Second + } + + r, err := p.Stack.FindRoute(p.NICID, p.LocalAddr, p.Addr, ProtocolNumber) + if err != nil { + return err + } + + netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber} + ep := &pingEndpoint{ + stack: p.Stack, + pktCh: make(chan buffer.View, 1), + } + id := stack.TransportEndpointID{ + LocalAddress: r.LocalAddress, + RemoteAddress: p.Addr, + } + + _, err = p.Stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) { + id.LocalPort = port + err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id, ep) + switch err { + case nil: + return true, nil + case tcpip.ErrPortInUse: + return false, nil + default: + return false, err + } + }) + if err != nil { + return err + } + defer p.Stack.UnregisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id) + + v := buffer.NewView(4) + binary.BigEndian.PutUint16(v[0:], id.LocalPort) + + start := time.Now() + + done := make(chan struct{}) + go func(count int) { + loop: + for ; count > 0; count-- { + select { + case v := <-ep.pktCh: + seq := binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize+2:]) + ch <- PingReply{ + Duration: time.Since(start) - time.Duration(seq)*wait, + SeqNumber: seq, + } + case <-ctx.Done(): + break loop + } + } + close(done) + }(int(count)) + defer func() { <-done }() + + t := time.NewTicker(wait) + defer t.Stop() + for seq := uint16(0); seq < count; seq++ { + select { + case <-t.C: + case <-ctx.Done(): + return nil + } + binary.BigEndian.PutUint16(v[2:], seq) + sent := time.Now() + if err := sendICMPv4(&r, header.ICMPv4Echo, 0, v); err != nil { + ch <- PingReply{ + Error: err, + Duration: time.Since(sent), + SeqNumber: seq, + } + } + } + return nil +} + +// PingReply summarizes an ICMP echo reply. +type PingReply struct { + Error *tcpip.Error // reports any errors sending a ping request + Duration time.Duration + SeqNumber uint16 +} + +type pingProtocol struct{} + +func (*pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrNotSupported // endpoints are created directly +} + +func (*pingProtocol) Number() tcpip.TransportProtocolNumber { return pingProtocolNumber } + +func (*pingProtocol) MinimumPacketSize() int { return header.ICMPv4EchoMinimumSize } + +func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + ident := binary.BigEndian.Uint16(v[4:]) + return 0, ident, nil +} + +func (*pingProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool { + return true +} + +// SetOption implements TransportProtocol.SetOption. +func (p *pingProtocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements TransportProtocol.Option. +func (p *pingProtocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func init() { + stack.RegisterTransportProtocolFactory(PingProtocolName, func() stack.TransportProtocol { + return &pingProtocol{} + }) +} + +type pingEndpoint struct { + stack *stack.Stack + pktCh chan buffer.View +} + +func (e *pingEndpoint) Close() { + close(e.pktCh) +} + +func (e *pingEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) { + select { + case e.pktCh <- vv.ToView(): + default: + } +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *pingEndpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { +} diff --git a/pkg/tcpip/network/ipv4/icmp_test.go b/pkg/tcpip/network/ipv4/icmp_test.go new file mode 100644 index 000000000..378fba74b --- /dev/null +++ b/pkg/tcpip/network/ipv4/icmp_test.go @@ -0,0 +1,124 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipv4_test + +import ( + "context" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +const stackAddr = "\x0a\x00\x00\x01" + +type testContext struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack +} + +func newTestContext(t *testing.T) *testContext { + s := stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName}) + + const defaultMTU = 65536 + id, linkEP := channel.New(256, defaultMTU, "") + if testing.Verbose() { + id = sniffer.New(id) + } + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }}) + + return &testContext{ + t: t, + s: s, + linkEP: linkEP, + } +} + +func (c *testContext) cleanup() { + close(c.linkEP.C) +} + +func (c *testContext) loopback() { + go func() { + for pkt := range c.linkEP.C { + v := make(buffer.View, len(pkt.Header)+len(pkt.Payload)) + copy(v, pkt.Header) + copy(v[len(pkt.Header):], pkt.Payload) + vv := v.ToVectorisedView([1]buffer.View{}) + c.linkEP.Inject(pkt.Proto, &vv) + } + }() +} + +func TestEcho(t *testing.T) { + c := newTestContext(t) + defer c.cleanup() + c.loopback() + + ch := make(chan ipv4.PingReply, 1) + p := ipv4.Pinger{ + Stack: c.s, + NICID: 1, + Addr: stackAddr, + Wait: 10 * time.Millisecond, + Count: 1, // one ping only + } + if err := p.Ping(context.Background(), ch); err != nil { + t.Fatalf("icmp.Ping failed: %v", err) + } + + ping := <-ch + if ping.Error != nil { + t.Errorf("bad ping response: %v", ping.Error) + } +} + +func TestEchoSequence(t *testing.T) { + c := newTestContext(t) + defer c.cleanup() + c.loopback() + + const numPings = 3 + ch := make(chan ipv4.PingReply, numPings) + p := ipv4.Pinger{ + Stack: c.s, + NICID: 1, + Addr: stackAddr, + Wait: 10 * time.Millisecond, + Count: numPings, + } + if err := p.Ping(context.Background(), ch); err != nil { + t.Fatalf("icmp.Ping failed: %v", err) + } + + for i := uint16(0); i < numPings; i++ { + ping := <-ch + if ping.Error != nil { + t.Errorf("i=%d bad ping response: %v", i, ping.Error) + } + if ping.SeqNumber != i { + t.Errorf("SeqNumber=%d, want %d", ping.SeqNumber, i) + } + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go new file mode 100644 index 000000000..4cc2a2fd4 --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -0,0 +1,233 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package ipv4 contains the implementation of the ipv4 network protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the +// network protocols when calling stack.New(). Then endpoints can be created +// by passing ipv4.ProtocolNumber as the network protocol number when calling +// Stack.NewEndpoint(). +package ipv4 + +import ( + "sync/atomic" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/fragmentation" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/hash" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +const ( + // ProtocolName is the string representation of the ipv4 protocol name. + ProtocolName = "ipv4" + + // ProtocolNumber is the ipv4 protocol number. + ProtocolNumber = header.IPv4ProtocolNumber + + // maxTotalSize is maximum size that can be encoded in the 16-bit + // TotalLength field of the ipv4 header. + maxTotalSize = 0xffff + + // buckets is the number of identifier buckets. + buckets = 2048 +) + +type address [header.IPv4AddressSize]byte + +type endpoint struct { + nicid tcpip.NICID + id stack.NetworkEndpointID + address address + linkEP stack.LinkEndpoint + dispatcher stack.TransportDispatcher + echoRequests chan echoRequest + fragmentation *fragmentation.Fragmentation +} + +func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint { + e := &endpoint{ + nicid: nicid, + linkEP: linkEP, + dispatcher: dispatcher, + echoRequests: make(chan echoRequest, 10), + fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + } + copy(e.address[:], addr) + e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])} + + go e.echoReplier() + + return e +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) +} + +// Capabilities implements stack.NetworkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// NICID returns the ID of the NIC this endpoint belongs to. +func (e *endpoint) NICID() tcpip.NICID { + return e.nicid +} + +// ID returns the ipv4 endpoint ID. +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &e.id +} + +// MaxHeaderLength returns the maximum length needed by ipv4 headers (and +// underlying protocols). +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error { + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + length := uint16(hdr.UsedLength() + len(payload)) + id := uint32(0) + if length > header.IPv4MaximumHeaderSize+8 { + // Packets of 68 bytes or less are required by RFC 791 to not be + // fragmented, so we only assign ids to larger packets. + id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1) + } + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: length, + ID: uint16(id), + TTL: 65, + Protocol: uint8(protocol), + SrcAddr: tcpip.Address(e.address[:]), + DstAddr: r.RemoteAddress, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) +} + +// HandlePacket is called by the link layer when new ipv4 packets arrive for +// this endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { + h := header.IPv4(vv.First()) + if !h.IsValid(vv.Size()) { + return + } + + hlen := int(h.HeaderLength()) + tlen := int(h.TotalLength()) + vv.TrimFront(hlen) + vv.CapLength(tlen - hlen) + + more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 + if more || h.FragmentOffset() != 0 { + // The packet is a fragment, let's try to reassemble it. + last := h.FragmentOffset() + uint16(vv.Size()) - 1 + tt, ready := e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + if !ready { + return + } + vv = &tt + } + p := h.TransportProtocol() + if p == header.ICMPv4ProtocolNumber { + e.handleICMP(r, vv) + return + } + e.dispatcher.DeliverTransportPacket(r, p, vv) +} + +// Close cleans up resources associated with the endpoint. +func (e *endpoint) Close() { + close(e.echoRequests) +} + +type protocol struct{} + +// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported +// only for tests that short-circuit the stack. Regular use of the protocol is +// done via the stack, which gets a protocol descriptor from the init() function +// below. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} +} + +// Number returns the ipv4 protocol number. +func (p *protocol) Number() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + +// MinimumPacketSize returns the minimum valid ipv4 packet size. +func (p *protocol) MinimumPacketSize() int { + return header.IPv4MinimumSize +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv4(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint creates a new ipv4 endpoint. +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + return newEndpoint(nicid, addr, dispatcher, linkEP), nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements NetworkProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// calculateMTU calculates the network-layer payload MTU based on the link-layer +// payload mtu. +func calculateMTU(mtu uint32) uint32 { + if mtu > maxTotalSize { + mtu = maxTotalSize + } + return mtu - header.IPv4MinimumSize +} + +// hashRoute calculates a hash value for the given route. It uses the source & +// destination address, the transport protocol number, and a random initial +// value (generated once on initialization) to generate the hash. +func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { + t := r.LocalAddress + a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + t = r.RemoteAddress + b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + return hash.Hash3Words(a, b, uint32(protocol), hashIV) +} + +var ( + ids []uint32 + hashIV uint32 +) + +func init() { + ids = make([]uint32, buckets) + + // Randomly initialize hashIV and the ids. + r := hash.RandN32(1 + buckets) + for i := range ids { + ids[i] = r[i] + } + hashIV = r[buckets] + + stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { + return &protocol{} + }) +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD new file mode 100644 index 000000000..db7da0af3 --- /dev/null +++ b/pkg/tcpip/network/ipv6/BUILD @@ -0,0 +1,21 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "ipv6", + srcs = [ + "icmp.go", + "ipv6.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6", + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go new file mode 100644 index 000000000..0fc6dcce2 --- /dev/null +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -0,0 +1,80 @@ +// Copyright 2017 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipv6 + +import ( + "encoding/binary" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// handleControl handles the case when an ICMP packet contains the headers of +// the original packet that caused the ICMP one to be sent. This information is +// used to find out which transport endpoint must be notified about the ICMP +// packet. +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { + h := header.IPv6(vv.First()) + + // We don't use IsValid() here because ICMP only requires that up to + // 1280 bytes of the original packet be included. So it's likely that it + // is truncated, which would cause IsValid to return false. + // + // Drop packet if it doesn't have the basic IPv6 header or if the + // original source address doesn't match the endpoint's address. + if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + return + } + + // Skip the IP header, then handle the fragmentation header if there + // is one. + vv.TrimFront(header.IPv6MinimumSize) + p := h.TransportProtocol() + if p == header.IPv6FragmentHeader { + f := header.IPv6Fragment(vv.First()) + if !f.IsValid() || f.FragmentOffset() != 0 { + // We can't handle fragments that aren't at offset 0 + // because they don't have the transport headers. + return + } + + // Skip fragmentation header and find out the actual protocol + // number. + vv.TrimFront(header.IPv6FragmentHeaderSize) + p = f.TransportProtocol() + } + + // Deliver the control packet to the transport endpoint. + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) +} + +func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) { + v := vv.First() + if len(v) < header.ICMPv6MinimumSize { + return + } + h := header.ICMPv6(v) + + switch h.Type() { + case header.ICMPv6PacketTooBig: + if len(v) < header.ICMPv6PacketTooBigMinimumSize { + return + } + vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:]) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + + case header.ICMPv6DstUnreachable: + if len(v) < header.ICMPv6DstUnreachableMinimumSize { + return + } + vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + switch h.Code() { + case header.ICMPv6PortUnreachable: + e.handleControl(stack.ControlPortUnreachable, 0, vv) + } + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go new file mode 100644 index 000000000..15654cbbd --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -0,0 +1,172 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package ipv6 contains the implementation of the ipv6 network protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the +// network protocols when calling stack.New(). Then endpoints can be created +// by passing ipv6.ProtocolNumber as the network protocol number when calling +// Stack.NewEndpoint(). +package ipv6 + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +const ( + // ProtocolName is the string representation of the ipv6 protocol name. + ProtocolName = "ipv6" + + // ProtocolNumber is the ipv6 protocol number. + ProtocolNumber = header.IPv6ProtocolNumber + + // maxTotalSize is maximum size that can be encoded in the 16-bit + // PayloadLength field of the ipv6 header. + maxPayloadSize = 0xffff +) + +type address [header.IPv6AddressSize]byte + +type endpoint struct { + nicid tcpip.NICID + id stack.NetworkEndpointID + address address + linkEP stack.LinkEndpoint + dispatcher stack.TransportDispatcher +} + +func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint { + e := &endpoint{nicid: nicid, linkEP: linkEP, dispatcher: dispatcher} + copy(e.address[:], addr) + e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])} + return e +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) +} + +// NICID returns the ID of the NIC this endpoint belongs to. +func (e *endpoint) NICID() tcpip.NICID { + return e.nicid +} + +// ID returns the ipv6 endpoint ID. +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &e.id +} + +// Capabilities implements stack.NetworkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// MaxHeaderLength returns the maximum length needed by ipv6 headers (and +// underlying protocols). +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error { + length := uint16(hdr.UsedLength()) + if payload != nil { + length += uint16(len(payload)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: length, + NextHeader: uint8(protocol), + HopLimit: 65, + SrcAddr: tcpip.Address(e.address[:]), + DstAddr: r.RemoteAddress, + }) + + return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) +} + +// HandlePacket is called by the link layer when new ipv6 packets arrive for +// this endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { + h := header.IPv6(vv.First()) + if !h.IsValid(vv.Size()) { + return + } + + vv.TrimFront(header.IPv6MinimumSize) + vv.CapLength(int(h.PayloadLength())) + + p := h.TransportProtocol() + if p == header.ICMPv6ProtocolNumber { + e.handleICMP(r, vv) + return + } + + e.dispatcher.DeliverTransportPacket(r, p, vv) +} + +// Close cleans up resources associated with the endpoint. +func (*endpoint) Close() {} + +type protocol struct{} + +// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported +// only for tests that short-circuit the stack. Regular use of the protocol is +// done via the stack, which gets a protocol descriptor from the init() function +// below. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} +} + +// Number returns the ipv6 protocol number. +func (p *protocol) Number() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + +// MinimumPacketSize returns the minimum valid ipv6 packet size. +func (p *protocol) MinimumPacketSize() int { + return header.IPv6MinimumSize +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv6(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint creates a new ipv6 endpoint. +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + return newEndpoint(nicid, addr, dispatcher, linkEP), nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements NetworkProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// calculateMTU calculates the network-layer payload MTU based on the link-layer +// payload mtu. +func calculateMTU(mtu uint32) uint32 { + mtu -= header.IPv6MinimumSize + if mtu <= maxPayloadSize { + return mtu + } + return maxPayloadSize +} + +func init() { + stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { + return &protocol{} + }) +} |