summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2019-06-02 06:44:55 +0000
committergVisor bot <gvisor-bot@google.com>2019-06-02 06:44:55 +0000
commitceb0d792f328d1fc0692197d8856a43c3936a571 (patch)
tree83155f302eff44a78bcc30a3a08f4efe59a79379 /pkg/tcpip
parentdeb7ecf1e46862d54f4b102f2d163cfbcfc37f3b (diff)
parent216da0b733dbed9aad9b2ab92ac75bcb906fd7ee (diff)
Merge 216da0b7 (automated)
Diffstat (limited to 'pkg/tcpip')
-rwxr-xr-xpkg/tcpip/buffer/buffer_state_autogen.go24
-rw-r--r--pkg/tcpip/buffer/prependable.go74
-rw-r--r--pkg/tcpip/buffer/view.go158
-rw-r--r--pkg/tcpip/hash/jenkins/jenkins.go80
-rwxr-xr-xpkg/tcpip/hash/jenkins/jenkins_state_autogen.go4
-rw-r--r--pkg/tcpip/header/arp.go100
-rw-r--r--pkg/tcpip/header/checksum.go94
-rw-r--r--pkg/tcpip/header/eth.go74
-rw-r--r--pkg/tcpip/header/gue.go73
-rwxr-xr-xpkg/tcpip/header/header_state_autogen.go42
-rw-r--r--pkg/tcpip/header/icmpv4.go108
-rw-r--r--pkg/tcpip/header/icmpv6.go121
-rw-r--r--pkg/tcpip/header/interfaces.go92
-rw-r--r--pkg/tcpip/header/ipv4.go282
-rw-r--r--pkg/tcpip/header/ipv6.go248
-rw-r--r--pkg/tcpip/header/ipv6_fragment.go146
-rw-r--r--pkg/tcpip/header/tcp.go543
-rw-r--r--pkg/tcpip/header/udp.go110
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go372
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_unsafe.go32
-rwxr-xr-xpkg/tcpip/link/fdbased/fdbased_state_autogen.go4
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go25
-rw-r--r--pkg/tcpip/link/fdbased/mmap_amd64.go194
-rw-r--r--pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go84
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go309
-rw-r--r--pkg/tcpip/link/loopback/loopback.go87
-rwxr-xr-xpkg/tcpip/link/loopback/loopback_state_autogen.go4
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64.s40
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go60
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_unsafe.go29
-rw-r--r--pkg/tcpip/link/rawfile/errors.go70
-rwxr-xr-xpkg/tcpip/link/rawfile/rawfile_state_autogen.go4
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go182
-rw-r--r--pkg/tcpip/link/sniffer/pcap.go66
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go408
-rwxr-xr-xpkg/tcpip/link/sniffer/sniffer_state_autogen.go4
-rw-r--r--pkg/tcpip/network/arp/arp.go203
-rwxr-xr-xpkg/tcpip/network/arp/arp_state_autogen.go4
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go77
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go134
-rwxr-xr-xpkg/tcpip/network/fragmentation/fragmentation_state_autogen.go38
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go118
-rwxr-xr-xpkg/tcpip/network/fragmentation/reassembler_list.go173
-rw-r--r--pkg/tcpip/network/hash/hash.go93
-rwxr-xr-xpkg/tcpip/network/hash/hash_state_autogen.go4
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go160
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go344
-rwxr-xr-xpkg/tcpip/network/ipv4/ipv4_state_autogen.go4
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go297
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go207
-rwxr-xr-xpkg/tcpip/network/ipv6/ipv6_state_autogen.go4
-rw-r--r--pkg/tcpip/ports/ports.go209
-rwxr-xr-xpkg/tcpip/ports/ports_state_autogen.go4
-rw-r--r--pkg/tcpip/seqnum/seqnum.go67
-rwxr-xr-xpkg/tcpip/seqnum/seqnum_state_autogen.go4
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go306
-rw-r--r--pkg/tcpip/stack/nic.go728
-rw-r--r--pkg/tcpip/stack/registration.go441
-rw-r--r--pkg/tcpip/stack/route.go189
-rw-r--r--pkg/tcpip/stack/stack.go1095
-rw-r--r--pkg/tcpip/stack/stack_global_state.go19
-rwxr-xr-xpkg/tcpip/stack/stack_state_autogen.go59
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go420
-rw-r--r--pkg/tcpip/tcpip.go1055
-rwxr-xr-xpkg/tcpip/tcpip_state_autogen.go40
-rw-r--r--pkg/tcpip/time_unsafe.go45
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go710
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go90
-rwxr-xr-xpkg/tcpip/transport/icmp/icmp_packet_list.go173
-rwxr-xr-xpkg/tcpip/transport/icmp/icmp_state_autogen.go98
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go136
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go521
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go88
-rwxr-xr-xpkg/tcpip/transport/raw/packet_list.go173
-rwxr-xr-xpkg/tcpip/transport/raw/raw_state_autogen.go96
-rw-r--r--pkg/tcpip/transport/tcp/accept.go499
-rw-r--r--pkg/tcpip/transport/tcp/connect.go1066
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go233
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go1741
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go362
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go171
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go250
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go221
-rw-r--r--pkg/tcpip/transport/tcp/reno.go103
-rw-r--r--pkg/tcpip/transport/tcp/sack.go99
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard.go306
-rw-r--r--pkg/tcpip/transport/tcp/segment.go186
-rw-r--r--pkg/tcpip/transport/tcp/segment_heap.go46
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go79
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go82
-rw-r--r--pkg/tcpip/transport/tcp/snd.go1180
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go50
-rwxr-xr-xpkg/tcpip/transport/tcp/tcp_segment_list.go173
-rwxr-xr-xpkg/tcpip/transport/tcp/tcp_state_autogen.go400
-rw-r--r--pkg/tcpip/transport/tcp/timer.go141
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go1002
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go112
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go96
-rw-r--r--pkg/tcpip/transport/udp/protocol.go90
-rwxr-xr-xpkg/tcpip/transport/udp/udp_packet_list.go173
-rwxr-xr-xpkg/tcpip/transport/udp/udp_state_autogen.go128
101 files changed, 21962 insertions, 0 deletions
diff --git a/pkg/tcpip/buffer/buffer_state_autogen.go b/pkg/tcpip/buffer/buffer_state_autogen.go
new file mode 100755
index 000000000..7e51a28e8
--- /dev/null
+++ b/pkg/tcpip/buffer/buffer_state_autogen.go
@@ -0,0 +1,24 @@
+// automatically generated by stateify.
+
+package buffer
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *VectorisedView) beforeSave() {}
+func (x *VectorisedView) save(m state.Map) {
+ x.beforeSave()
+ m.Save("views", &x.views)
+ m.Save("size", &x.size)
+}
+
+func (x *VectorisedView) afterLoad() {}
+func (x *VectorisedView) load(m state.Map) {
+ m.Load("views", &x.views)
+ m.Load("size", &x.size)
+}
+
+func init() {
+ state.Register("buffer.VectorisedView", (*VectorisedView)(nil), state.Fns{Save: (*VectorisedView).save, Load: (*VectorisedView).load})
+}
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go
new file mode 100644
index 000000000..4287464f3
--- /dev/null
+++ b/pkg/tcpip/buffer/prependable.go
@@ -0,0 +1,74 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+// Prependable is a buffer that grows backwards, that is, more data can be
+// prepended to it. It is useful when building networking packets, where each
+// protocol adds its own headers to the front of the higher-level protocol
+// header and payload; for example, TCP would prepend its header to the payload,
+// then IP would prepend its own, then ethernet.
+type Prependable struct {
+ // Buf is the buffer backing the prependable buffer.
+ buf View
+
+ // usedIdx is the index where the used part of the buffer begins.
+ usedIdx int
+}
+
+// NewPrependable allocates a new prependable buffer with the given size.
+func NewPrependable(size int) Prependable {
+ return Prependable{buf: NewView(size), usedIdx: size}
+}
+
+// NewPrependableFromView creates an entirely-used Prependable from a View.
+//
+// NewPrependableFromView takes ownership of v. Note that since the entire
+// prependable is used, further attempts to call Prepend will note that size >
+// p.usedIdx and return nil.
+func NewPrependableFromView(v View) Prependable {
+ return Prependable{buf: v, usedIdx: 0}
+}
+
+// View returns a View of the backing buffer that contains all prepended
+// data so far.
+func (p Prependable) View() View {
+ return p.buf[p.usedIdx:]
+}
+
+// UsedLength returns the number of bytes used so far.
+func (p Prependable) UsedLength() int {
+ return len(p.buf) - p.usedIdx
+}
+
+// AvailableLength returns the number of bytes used so far.
+func (p Prependable) AvailableLength() int {
+ return p.usedIdx
+}
+
+// TrimBack removes size bytes from the end.
+func (p *Prependable) TrimBack(size int) {
+ p.buf = p.buf[:len(p.buf)-size]
+}
+
+// Prepend reserves the requested space in front of the buffer, returning a
+// slice that represents the reserved space.
+func (p *Prependable) Prepend(size int) []byte {
+ if size > p.usedIdx {
+ return nil
+ }
+
+ p.usedIdx -= size
+ return p.View()[:size:size]
+}
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
new file mode 100644
index 000000000..1a9d40778
--- /dev/null
+++ b/pkg/tcpip/buffer/view.go
@@ -0,0 +1,158 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package buffer provides the implementation of a buffer view.
+package buffer
+
+// View is a slice of a buffer, with convenience methods.
+type View []byte
+
+// NewView allocates a new buffer and returns an initialized view that covers
+// the whole buffer.
+func NewView(size int) View {
+ return make(View, size)
+}
+
+// NewViewFromBytes allocates a new buffer and copies in the given bytes.
+func NewViewFromBytes(b []byte) View {
+ return append(View(nil), b...)
+}
+
+// TrimFront removes the first "count" bytes from the visible section of the
+// buffer.
+func (v *View) TrimFront(count int) {
+ *v = (*v)[count:]
+}
+
+// CapLength irreversibly reduces the length of the visible section of the
+// buffer to the value specified.
+func (v *View) CapLength(length int) {
+ // We also set the slice cap because if we don't, one would be able to
+ // expand the view back to include the region just excluded. We want to
+ // prevent that to avoid potential data leak if we have uninitialized
+ // data in excluded region.
+ *v = (*v)[:length:length]
+}
+
+// ToVectorisedView returns a VectorisedView containing the receiver.
+func (v View) ToVectorisedView() VectorisedView {
+ return NewVectorisedView(len(v), []View{v})
+}
+
+// VectorisedView is a vectorised version of View using non contigous memory.
+// It supports all the convenience methods supported by View.
+//
+// +stateify savable
+type VectorisedView struct {
+ views []View
+ size int
+}
+
+// NewVectorisedView creates a new vectorised view from an already-allocated slice
+// of View and sets its size.
+func NewVectorisedView(size int, views []View) VectorisedView {
+ return VectorisedView{views: views, size: size}
+}
+
+// TrimFront removes the first "count" bytes of the vectorised view.
+func (vv *VectorisedView) TrimFront(count int) {
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ vv.views[0].TrimFront(count)
+ return
+ }
+ count -= len(vv.views[0])
+ vv.RemoveFirst()
+ }
+}
+
+// CapLength irreversibly reduces the length of the vectorised view.
+func (vv *VectorisedView) CapLength(length int) {
+ if length < 0 {
+ length = 0
+ }
+ if vv.size < length {
+ return
+ }
+ vv.size = length
+ for i := range vv.views {
+ v := &vv.views[i]
+ if len(*v) >= length {
+ if length == 0 {
+ vv.views = vv.views[:i]
+ } else {
+ v.CapLength(length)
+ vv.views = vv.views[:i+1]
+ }
+ return
+ }
+ length -= len(*v)
+ }
+}
+
+// Clone returns a clone of this VectorisedView.
+// If the buffer argument is large enough to contain all the Views of this VectorisedView,
+// the method will avoid allocations and use the buffer to store the Views of the clone.
+func (vv VectorisedView) Clone(buffer []View) VectorisedView {
+ return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
+}
+
+// First returns the first view of the vectorised view.
+func (vv VectorisedView) First() View {
+ if len(vv.views) == 0 {
+ return nil
+ }
+ return vv.views[0]
+}
+
+// RemoveFirst removes the first view of the vectorised view.
+func (vv *VectorisedView) RemoveFirst() {
+ if len(vv.views) == 0 {
+ return
+ }
+ vv.size -= len(vv.views[0])
+ vv.views = vv.views[1:]
+}
+
+// Size returns the size in bytes of the entire content stored in the vectorised view.
+func (vv VectorisedView) Size() int {
+ return vv.size
+}
+
+// ToView returns a single view containing the content of the vectorised view.
+//
+// If the vectorised view contains a single view, that view will be returned
+// directly.
+func (vv VectorisedView) ToView() View {
+ if len(vv.views) == 1 {
+ return vv.views[0]
+ }
+ u := make([]byte, 0, vv.size)
+ for _, v := range vv.views {
+ u = append(u, v...)
+ }
+ return u
+}
+
+// Views returns the slice containing the all views.
+func (vv VectorisedView) Views() []View {
+ return vv.views
+}
+
+// Append appends the views in a vectorised view to this vectorised view.
+func (vv *VectorisedView) Append(vv2 VectorisedView) {
+ vv.views = append(vv.views, vv2.views...)
+ vv.size += vv2.size
+}
diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go
new file mode 100644
index 000000000..52c22230e
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/jenkins.go
@@ -0,0 +1,80 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package jenkins implements Jenkins's one_at_a_time, non-cryptographic hash
+// functions created by by Bob Jenkins.
+//
+// See https://en.wikipedia.org/wiki/Jenkins_hash_function#cite_note-dobbsx-1
+//
+package jenkins
+
+import (
+ "hash"
+)
+
+// Sum32 represents Jenkins's one_at_a_time hash.
+//
+// Use the Sum32 type directly (as opposed to New32 below)
+// to avoid allocations.
+type Sum32 uint32
+
+// New32 returns a new 32-bit Jenkins's one_at_a_time hash.Hash.
+//
+// Its Sum method will lay the value out in big-endian byte order.
+func New32() hash.Hash32 {
+ var s Sum32
+ return &s
+}
+
+// Reset resets the hash to its initial state.
+func (s *Sum32) Reset() { *s = 0 }
+
+// Sum32 returns the hash value
+func (s *Sum32) Sum32() uint32 {
+ hash := *s
+
+ hash += (hash << 3)
+ hash ^= hash >> 11
+ hash += hash << 15
+
+ return uint32(hash)
+}
+
+// Write adds more data to the running hash.
+//
+// It never returns an error.
+func (s *Sum32) Write(data []byte) (int, error) {
+ hash := *s
+ for _, b := range data {
+ hash += Sum32(b)
+ hash += hash << 10
+ hash ^= hash >> 6
+ }
+ *s = hash
+ return len(data), nil
+}
+
+// Size returns the number of bytes Sum will return.
+func (s *Sum32) Size() int { return 4 }
+
+// BlockSize returns the hash's underlying block size.
+func (s *Sum32) BlockSize() int { return 1 }
+
+// Sum appends the current hash to in and returns the resulting slice.
+//
+// It does not change the underlying hash state.
+func (s *Sum32) Sum(in []byte) []byte {
+ v := s.Sum32()
+ return append(in, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+}
diff --git a/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go
new file mode 100755
index 000000000..310f0ee6d
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package jenkins
+
diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go
new file mode 100644
index 000000000..55fe7292c
--- /dev/null
+++ b/pkg/tcpip/header/arp.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import "gvisor.googlesource.com/gvisor/pkg/tcpip"
+
+const (
+ // ARPProtocolNumber is the ARP network protocol number.
+ ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
+
+ // ARPSize is the size of an IPv4-over-Ethernet ARP packet.
+ ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
+)
+
+// ARPOp is an ARP opcode.
+type ARPOp uint16
+
+// Typical ARP opcodes defined in RFC 826.
+const (
+ ARPRequest ARPOp = 1
+ ARPReply ARPOp = 2
+)
+
+// ARP is an ARP packet stored in a byte array as described in RFC 826.
+type ARP []byte
+
+func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
+func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
+func (a ARP) hardwareAddressSize() int { return int(a[4]) }
+func (a ARP) protocolAddressSize() int { return int(a[5]) }
+
+// Op is the ARP opcode.
+func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }
+
+// SetOp sets the ARP opcode.
+func (a ARP) SetOp(op ARPOp) {
+ a[6] = uint8(op >> 8)
+ a[7] = uint8(op)
+}
+
+// SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet.
+func (a ARP) SetIPv4OverEthernet() {
+ a[0], a[1] = 0, 1 // htypeEthernet
+ a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
+ a[4] = 6 // macSize
+ a[5] = uint8(IPv4AddressSize)
+}
+
+// HardwareAddressSender is the link address of the sender.
+// It is a view on to the ARP packet so it can be used to set the value.
+func (a ARP) HardwareAddressSender() []byte {
+ const s = 8
+ return a[s : s+6]
+}
+
+// ProtocolAddressSender is the protocol address of the sender.
+// It is a view on to the ARP packet so it can be used to set the value.
+func (a ARP) ProtocolAddressSender() []byte {
+ const s = 8 + 6
+ return a[s : s+4]
+}
+
+// HardwareAddressTarget is the link address of the target.
+// It is a view on to the ARP packet so it can be used to set the value.
+func (a ARP) HardwareAddressTarget() []byte {
+ const s = 8 + 6 + 4
+ return a[s : s+6]
+}
+
+// ProtocolAddressTarget is the protocol address of the target.
+// It is a view on to the ARP packet so it can be used to set the value.
+func (a ARP) ProtocolAddressTarget() []byte {
+ const s = 8 + 6 + 4 + 6
+ return a[s : s+4]
+}
+
+// IsValid reports whether this is an ARP packet for IPv4 over Ethernet.
+func (a ARP) IsValid() bool {
+ if len(a) < ARPSize {
+ return false
+ }
+ const htypeEthernet = 1
+ const macSize = 6
+ return a.hardwareAddressSpace() == htypeEthernet &&
+ a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
+ a.hardwareAddressSize() == macSize &&
+ a.protocolAddressSize() == IPv4AddressSize
+}
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
new file mode 100644
index 000000000..2eaa7938a
--- /dev/null
+++ b/pkg/tcpip/header/checksum.go
@@ -0,0 +1,94 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package header provides the implementation of the encoding and decoding of
+// network protocol headers.
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func calculateChecksum(buf []byte, initial uint32) uint16 {
+ v := initial
+
+ l := len(buf)
+ if l&1 != 0 {
+ l--
+ v += uint32(buf[l]) << 8
+ }
+
+ for i := 0; i < l; i += 2 {
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ }
+
+ return ChecksumCombine(uint16(v), uint16(v>>16))
+}
+
+// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
+// given byte array.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func Checksum(buf []byte, initial uint16) uint16 {
+ return calculateChecksum(buf, uint32(initial))
+}
+
+// ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in
+// the given VectorizedView.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
+ var odd bool
+ sum := initial
+ for _, v := range vv.Views() {
+ if len(v) == 0 {
+ continue
+ }
+ s := uint32(sum)
+ if odd {
+ s += uint32(v[0])
+ v = v[1:]
+ }
+ odd = len(v)&1 != 0
+ sum = calculateChecksum(v, s)
+ }
+ return sum
+}
+
+// ChecksumCombine combines the two uint16 to form their checksum. This is done
+// by adding them and the carry.
+//
+// Note that checksum a must have been computed on an even number of bytes.
+func ChecksumCombine(a, b uint16) uint16 {
+ v := uint32(a) + uint32(b)
+ return uint16(v + v>>16)
+}
+
+// PseudoHeaderChecksum calculates the pseudo-header checksum for the given
+// destination protocol and network address. Pseudo-headers are needed by
+// transport layers when calculating their own checksum.
+func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, totalLen uint16) uint16 {
+ xsum := Checksum([]byte(srcAddr), 0)
+ xsum = Checksum([]byte(dstAddr), xsum)
+
+ // Add the length portion of the checksum to the pseudo-checksum.
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ xsum = Checksum(tmp, xsum)
+
+ return Checksum([]byte{0, uint8(protocol)}, xsum)
+}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
new file mode 100644
index 000000000..76143f454
--- /dev/null
+++ b/pkg/tcpip/header/eth.go
@@ -0,0 +1,74 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ dstMAC = 0
+ srcMAC = 6
+ ethType = 12
+)
+
+// EthernetFields contains the fields of an ethernet frame header. It is used to
+// describe the fields of a frame that needs to be encoded.
+type EthernetFields struct {
+ // SrcAddr is the "MAC source" field of an ethernet frame header.
+ SrcAddr tcpip.LinkAddress
+
+ // DstAddr is the "MAC destination" field of an ethernet frame header.
+ DstAddr tcpip.LinkAddress
+
+ // Type is the "ethertype" field of an ethernet frame header.
+ Type tcpip.NetworkProtocolNumber
+}
+
+// Ethernet represents an ethernet frame header stored in a byte array.
+type Ethernet []byte
+
+const (
+ // EthernetMinimumSize is the minimum size of a valid ethernet frame.
+ EthernetMinimumSize = 14
+
+ // EthernetAddressSize is the size, in bytes, of an ethernet address.
+ EthernetAddressSize = 6
+)
+
+// SourceAddress returns the "MAC source" field of the ethernet frame header.
+func (b Ethernet) SourceAddress() tcpip.LinkAddress {
+ return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
+}
+
+// DestinationAddress returns the "MAC destination" field of the ethernet frame
+// header.
+func (b Ethernet) DestinationAddress() tcpip.LinkAddress {
+ return tcpip.LinkAddress(b[dstMAC:][:EthernetAddressSize])
+}
+
+// Type returns the "ethertype" field of the ethernet frame header.
+func (b Ethernet) Type() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(b[ethType:]))
+}
+
+// Encode encodes all the fields of the ethernet frame header.
+func (b Ethernet) Encode(e *EthernetFields) {
+ binary.BigEndian.PutUint16(b[ethType:], uint16(e.Type))
+ copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
+ copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
+}
diff --git a/pkg/tcpip/header/gue.go b/pkg/tcpip/header/gue.go
new file mode 100644
index 000000000..10d358c0e
--- /dev/null
+++ b/pkg/tcpip/header/gue.go
@@ -0,0 +1,73 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+const (
+ typeHLen = 0
+ encapProto = 1
+)
+
+// GUEFields contains the fields of a GUE packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type GUEFields struct {
+ // Type is the "type" field of the GUE header.
+ Type uint8
+
+ // Control is the "control" field of the GUE header.
+ Control bool
+
+ // HeaderLength is the "header length" field of the GUE header. It must
+ // be at least 4 octets, and a multiple of 4 as well.
+ HeaderLength uint8
+
+ // Protocol is the "protocol" field of the GUE header. This is one of
+ // the IPPROTO_* values.
+ Protocol uint8
+}
+
+// GUE represents a Generic UDP Encapsulation header stored in a byte array, the
+// fields are described in https://tools.ietf.org/html/draft-ietf-nvo3-gue-01.
+type GUE []byte
+
+const (
+ // GUEMinimumSize is the minimum size of a valid GUE packet.
+ GUEMinimumSize = 4
+)
+
+// TypeAndControl returns the GUE packet type (top 3 bits of the first byte,
+// which includes the control bit).
+func (b GUE) TypeAndControl() uint8 {
+ return b[typeHLen] >> 5
+}
+
+// HeaderLength returns the total length of the GUE header.
+func (b GUE) HeaderLength() uint8 {
+ return 4 + 4*(b[typeHLen]&0x1f)
+}
+
+// Protocol returns the protocol field of the GUE header.
+func (b GUE) Protocol() uint8 {
+ return b[encapProto]
+}
+
+// Encode encodes all the fields of the GUE header.
+func (b GUE) Encode(i *GUEFields) {
+ ctl := uint8(0)
+ if i.Control {
+ ctl = 1 << 5
+ }
+ b[typeHLen] = ctl | i.Type<<6 | (i.HeaderLength-4)/4
+ b[encapProto] = i.Protocol
+}
diff --git a/pkg/tcpip/header/header_state_autogen.go b/pkg/tcpip/header/header_state_autogen.go
new file mode 100755
index 000000000..a8f4c4693
--- /dev/null
+++ b/pkg/tcpip/header/header_state_autogen.go
@@ -0,0 +1,42 @@
+// automatically generated by stateify.
+
+package header
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *SACKBlock) beforeSave() {}
+func (x *SACKBlock) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *SACKBlock) afterLoad() {}
+func (x *SACKBlock) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func (x *TCPOptions) beforeSave() {}
+func (x *TCPOptions) save(m state.Map) {
+ x.beforeSave()
+ m.Save("TS", &x.TS)
+ m.Save("TSVal", &x.TSVal)
+ m.Save("TSEcr", &x.TSEcr)
+ m.Save("SACKBlocks", &x.SACKBlocks)
+}
+
+func (x *TCPOptions) afterLoad() {}
+func (x *TCPOptions) load(m state.Map) {
+ m.Load("TS", &x.TS)
+ m.Load("TSVal", &x.TSVal)
+ m.Load("TSEcr", &x.TSEcr)
+ m.Load("SACKBlocks", &x.SACKBlocks)
+}
+
+func init() {
+ state.Register("header.SACKBlock", (*SACKBlock)(nil), state.Fns{Save: (*SACKBlock).save, Load: (*SACKBlock).load})
+ state.Register("header.TCPOptions", (*TCPOptions)(nil), state.Fns{Save: (*TCPOptions).save, Load: (*TCPOptions).load})
+}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
new file mode 100644
index 000000000..782e1053c
--- /dev/null
+++ b/pkg/tcpip/header/icmpv4.go
@@ -0,0 +1,108 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// ICMPv4 represents an ICMPv4 header stored in a byte array.
+type ICMPv4 []byte
+
+const (
+ // ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
+ ICMPv4MinimumSize = 4
+
+ // ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet.
+ ICMPv4EchoMinimumSize = 6
+
+ // ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP
+ // destination unreachable packet.
+ ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4
+
+ // ICMPv4ProtocolNumber is the ICMP transport protocol number.
+ ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
+)
+
+// ICMPv4Type is the ICMP type field described in RFC 792.
+type ICMPv4Type byte
+
+// Typical values of ICMPv4Type defined in RFC 792.
+const (
+ ICMPv4EchoReply ICMPv4Type = 0
+ ICMPv4DstUnreachable ICMPv4Type = 3
+ ICMPv4SrcQuench ICMPv4Type = 4
+ ICMPv4Redirect ICMPv4Type = 5
+ ICMPv4Echo ICMPv4Type = 8
+ ICMPv4TimeExceeded ICMPv4Type = 11
+ ICMPv4ParamProblem ICMPv4Type = 12
+ ICMPv4Timestamp ICMPv4Type = 13
+ ICMPv4TimestampReply ICMPv4Type = 14
+ ICMPv4InfoRequest ICMPv4Type = 15
+ ICMPv4InfoReply ICMPv4Type = 16
+)
+
+// Values for ICMP code as defined in RFC 792.
+const (
+ ICMPv4PortUnreachable = 3
+ ICMPv4FragmentationNeeded = 4
+)
+
+// Type is the ICMP type field.
+func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
+
+// SetType sets the ICMP type field.
+func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) }
+
+// Code is the ICMP code field. Its meaning depends on the value of Type.
+func (b ICMPv4) Code() byte { return b[1] }
+
+// SetCode sets the ICMP code field.
+func (b ICMPv4) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field.
+func (b ICMPv4) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[2:])
+}
+
+// SetChecksum sets the ICMP checksum field.
+func (b ICMPv4) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[2:], checksum)
+}
+
+// SourcePort implements Transport.SourcePort.
+func (ICMPv4) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (ICMPv4) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (ICMPv4) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (ICMPv4) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv4) Payload() []byte {
+ return b[ICMPv4MinimumSize:]
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
new file mode 100644
index 000000000..d0b10d849
--- /dev/null
+++ b/pkg/tcpip/header/icmpv6.go
@@ -0,0 +1,121 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// ICMPv6 represents an ICMPv6 header stored in a byte array.
+type ICMPv6 []byte
+
+const (
+ // ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
+ ICMPv6MinimumSize = 4
+
+ // ICMPv6ProtocolNumber is the ICMP transport protocol number.
+ ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
+
+ // ICMPv6NeighborSolicitMinimumSize is the minimum size of a
+ // neighbor solicitation packet.
+ ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16
+
+ // ICMPv6NeighborAdvertSize is size of a neighbor advertisement.
+ ICMPv6NeighborAdvertSize = 32
+
+ // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
+ ICMPv6EchoMinimumSize = 8
+
+ // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
+ // destination unreachable packet.
+ ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4
+
+ // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
+ // packet-too-big packet.
+ ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4
+)
+
+// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
+type ICMPv6Type byte
+
+// Typical values of ICMPv6Type defined in RFC 4443.
+const (
+ ICMPv6DstUnreachable ICMPv6Type = 1
+ ICMPv6PacketTooBig ICMPv6Type = 2
+ ICMPv6TimeExceeded ICMPv6Type = 3
+ ICMPv6ParamProblem ICMPv6Type = 4
+ ICMPv6EchoRequest ICMPv6Type = 128
+ ICMPv6EchoReply ICMPv6Type = 129
+
+ // Neighbor Discovery Protocol (NDP) messages, see RFC 4861.
+
+ ICMPv6RouterSolicit ICMPv6Type = 133
+ ICMPv6RouterAdvert ICMPv6Type = 134
+ ICMPv6NeighborSolicit ICMPv6Type = 135
+ ICMPv6NeighborAdvert ICMPv6Type = 136
+ ICMPv6RedirectMsg ICMPv6Type = 137
+)
+
+// Values for ICMP code as defined in RFC 4443.
+const (
+ ICMPv6PortUnreachable = 4
+)
+
+// Type is the ICMP type field.
+func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
+
+// SetType sets the ICMP type field.
+func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
+
+// Code is the ICMP code field. Its meaning depends on the value of Type.
+func (b ICMPv6) Code() byte { return b[1] }
+
+// SetCode sets the ICMP code field.
+func (b ICMPv6) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field.
+func (b ICMPv6) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[2:])
+}
+
+// SetChecksum calculates and sets the ICMP checksum field.
+func (b ICMPv6) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[2:], checksum)
+}
+
+// SourcePort implements Transport.SourcePort.
+func (ICMPv6) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (ICMPv6) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (ICMPv6) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (ICMPv6) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv6) Payload() []byte {
+ return b[ICMPv6MinimumSize:]
+}
diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go
new file mode 100644
index 000000000..fb250ea30
--- /dev/null
+++ b/pkg/tcpip/header/interfaces.go
@@ -0,0 +1,92 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ // MaxIPPacketSize is the maximum supported IP packet size, excluding
+ // jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
+ // in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
+ // 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
+ // m is the minimum IPv6 header size; we leave room for some potential
+ // IP options.
+ MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
+)
+
+// Transport offers generic methods to query and/or update the fields of the
+// header of a transport protocol buffer.
+type Transport interface {
+ // SourcePort returns the value of the "source port" field.
+ SourcePort() uint16
+
+ // Destination returns the value of the "destination port" field.
+ DestinationPort() uint16
+
+ // Checksum returns the value of the "checksum" field.
+ Checksum() uint16
+
+ // SetSourcePort sets the value of the "source port" field.
+ SetSourcePort(uint16)
+
+ // SetDestinationPort sets the value of the "destination port" field.
+ SetDestinationPort(uint16)
+
+ // SetChecksum sets the value of the "checksum" field.
+ SetChecksum(uint16)
+
+ // Payload returns the data carried in the transport buffer.
+ Payload() []byte
+}
+
+// Network offers generic methods to query and/or update the fields of the
+// header of a network protocol buffer.
+type Network interface {
+ // SourceAddress returns the value of the "source address" field.
+ SourceAddress() tcpip.Address
+
+ // DestinationAddress returns the value of the "destination address"
+ // field.
+ DestinationAddress() tcpip.Address
+
+ // Checksum returns the value of the "checksum" field.
+ Checksum() uint16
+
+ // SetSourceAddress sets the value of the "source address" field.
+ SetSourceAddress(tcpip.Address)
+
+ // SetDestinationAddress sets the value of the "destination address"
+ // field.
+ SetDestinationAddress(tcpip.Address)
+
+ // SetChecksum sets the value of the "checksum" field.
+ SetChecksum(uint16)
+
+ // TransportProtocol returns the number of the transport protocol
+ // stored in the payload.
+ TransportProtocol() tcpip.TransportProtocolNumber
+
+ // Payload returns a byte slice containing the payload of the network
+ // packet.
+ Payload() []byte
+
+ // TOS returns the values of the "type of service" and "flow label" fields.
+ TOS() (uint8, uint32)
+
+ // SetTOS sets the values of the "type of service" and "flow label" fields.
+ SetTOS(t uint8, l uint32)
+}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
new file mode 100644
index 000000000..96e461491
--- /dev/null
+++ b/pkg/tcpip/header/ipv4.go
@@ -0,0 +1,282 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ versIHL = 0
+ tos = 1
+ totalLen = 2
+ id = 4
+ flagsFO = 6
+ ttl = 8
+ protocol = 9
+ checksum = 10
+ srcAddr = 12
+ dstAddr = 16
+)
+
+// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv4Fields struct {
+ // IHL is the "internet header length" field of an IPv4 packet.
+ IHL uint8
+
+ // TOS is the "type of service" field of an IPv4 packet.
+ TOS uint8
+
+ // TotalLength is the "total length" field of an IPv4 packet.
+ TotalLength uint16
+
+ // ID is the "identification" field of an IPv4 packet.
+ ID uint16
+
+ // Flags is the "flags" field of an IPv4 packet.
+ Flags uint8
+
+ // FragmentOffset is the "fragment offset" field of an IPv4 packet.
+ FragmentOffset uint16
+
+ // TTL is the "time to live" field of an IPv4 packet.
+ TTL uint8
+
+ // Protocol is the "protocol" field of an IPv4 packet.
+ Protocol uint8
+
+ // Checksum is the "checksum" field of an IPv4 packet.
+ Checksum uint16
+
+ // SrcAddr is the "source ip address" of an IPv4 packet.
+ SrcAddr tcpip.Address
+
+ // DstAddr is the "destination ip address" of an IPv4 packet.
+ DstAddr tcpip.Address
+}
+
+// IPv4 represents an ipv4 header stored in a byte array.
+// Most of the methods of IPv4 access to the underlying slice without
+// checking the boundaries and could panic because of 'index out of range'.
+// Always call IsValid() to validate an instance of IPv4 before using other methods.
+type IPv4 []byte
+
+const (
+ // IPv4MinimumSize is the minimum size of a valid IPv4 packet.
+ IPv4MinimumSize = 20
+
+ // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
+ // that there are only 4 bits to represents the header length in 32-bit
+ // units, the header cannot exceed 15*4 = 60 bytes.
+ IPv4MaximumHeaderSize = 60
+
+ // IPv4AddressSize is the size, in bytes, of an IPv4 address.
+ IPv4AddressSize = 4
+
+ // IPv4ProtocolNumber is IPv4's network protocol number.
+ IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800
+
+ // IPv4Version is the version of the ipv4 protocol.
+ IPv4Version = 4
+
+ // IPv4Broadcast is the broadcast address of the IPv4 procotol.
+ IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff"
+
+ // IPv4Any is the non-routable IPv4 "any" meta address.
+ IPv4Any tcpip.Address = "\x00\x00\x00\x00"
+)
+
+// Flags that may be set in an IPv4 packet.
+const (
+ IPv4FlagMoreFragments = 1 << iota
+ IPv4FlagDontFragment
+)
+
+// IPVersion returns the version of IP used in the given packet. It returns -1
+// if the packet is not large enough to contain the version field.
+func IPVersion(b []byte) int {
+ // Length must be at least offset+length of version field.
+ if len(b) < versIHL+1 {
+ return -1
+ }
+ return int(b[versIHL] >> 4)
+}
+
+// HeaderLength returns the value of the "header length" field of the ipv4
+// header.
+func (b IPv4) HeaderLength() uint8 {
+ return (b[versIHL] & 0xf) * 4
+}
+
+// ID returns the value of the identifier field of the ipv4 header.
+func (b IPv4) ID() uint16 {
+ return binary.BigEndian.Uint16(b[id:])
+}
+
+// Protocol returns the value of the protocol field of the ipv4 header.
+func (b IPv4) Protocol() uint8 {
+ return b[protocol]
+}
+
+// Flags returns the "flags" field of the ipv4 header.
+func (b IPv4) Flags() uint8 {
+ return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
+}
+
+// TTL returns the "TTL" field of the ipv4 header.
+func (b IPv4) TTL() uint8 {
+ return b[ttl]
+}
+
+// FragmentOffset returns the "fragment offset" field of the ipv4 header.
+func (b IPv4) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[flagsFO:]) << 3
+}
+
+// TotalLength returns the "total length" field of the ipv4 header.
+func (b IPv4) TotalLength() uint16 {
+ return binary.BigEndian.Uint16(b[totalLen:])
+}
+
+// Checksum returns the checksum field of the ipv4 header.
+func (b IPv4) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[checksum:])
+}
+
+// SourceAddress returns the "source address" field of the ipv4 header.
+func (b IPv4) SourceAddress() tcpip.Address {
+ return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize])
+}
+
+// DestinationAddress returns the "destination address" field of the ipv4
+// header.
+func (b IPv4) DestinationAddress() tcpip.Address {
+ return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.Protocol())
+}
+
+// Payload implements Network.Payload.
+func (b IPv4) Payload() []byte {
+ return b[b.HeaderLength():][:b.PayloadLength()]
+}
+
+// PayloadLength returns the length of the payload portion of the ipv4 packet.
+func (b IPv4) PayloadLength() uint16 {
+ return b.TotalLength() - uint16(b.HeaderLength())
+}
+
+// TOS returns the "type of service" field of the ipv4 header.
+func (b IPv4) TOS() (uint8, uint32) {
+ return b[tos], 0
+}
+
+// SetTOS sets the "type of service" field of the ipv4 header.
+func (b IPv4) SetTOS(v uint8, _ uint32) {
+ b[tos] = v
+}
+
+// SetTotalLength sets the "total length" field of the ipv4 header.
+func (b IPv4) SetTotalLength(totalLength uint16) {
+ binary.BigEndian.PutUint16(b[totalLen:], totalLength)
+}
+
+// SetChecksum sets the checksum field of the ipv4 header.
+func (b IPv4) SetChecksum(v uint16) {
+ binary.BigEndian.PutUint16(b[checksum:], v)
+}
+
+// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
+// ipv4 header.
+func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
+ v := (uint16(flags) << 13) | (offset >> 3)
+ binary.BigEndian.PutUint16(b[flagsFO:], v)
+}
+
+// SetID sets the identification field.
+func (b IPv4) SetID(v uint16) {
+ binary.BigEndian.PutUint16(b[id:], v)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv4 header.
+func (b IPv4) SetSourceAddress(addr tcpip.Address) {
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], addr)
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv4
+// header.
+func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], addr)
+}
+
+// CalculateChecksum calculates the checksum of the ipv4 header.
+func (b IPv4) CalculateChecksum() uint16 {
+ return Checksum(b[:b.HeaderLength()], 0)
+}
+
+// Encode encodes all the fields of the ipv4 header.
+func (b IPv4) Encode(i *IPv4Fields) {
+ b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
+ b[tos] = i.TOS
+ b.SetTotalLength(i.TotalLength)
+ binary.BigEndian.PutUint16(b[id:], i.ID)
+ b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
+ b[ttl] = i.TTL
+ b[protocol] = i.Protocol
+ b.SetChecksum(i.Checksum)
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr)
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
+}
+
+// EncodePartial updates the total length and checksum fields of ipv4 header,
+// taking in the partial checksum, which is the checksum of the header without
+// the total length and checksum fields. It is useful in cases when similar
+// packets are produced.
+func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) {
+ b.SetTotalLength(totalLength)
+ checksum := Checksum(b[totalLen:totalLen+2], partialChecksum)
+ b.SetChecksum(^checksum)
+}
+
+// IsValid performs basic validation on the packet.
+func (b IPv4) IsValid(pktSize int) bool {
+ if len(b) < IPv4MinimumSize {
+ return false
+ }
+
+ hlen := int(b.HeaderLength())
+ tlen := int(b.TotalLength())
+ if hlen > tlen || tlen > pktSize {
+ return false
+ }
+
+ return true
+}
+
+// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
+// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
+// will be 1110 = 0xe0.
+func IsV4MulticastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv4AddressSize {
+ return false
+ }
+ return (addr[0] & 0xf0) == 0xe0
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
new file mode 100644
index 000000000..66820a466
--- /dev/null
+++ b/pkg/tcpip/header/ipv6.go
@@ -0,0 +1,248 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ versTCFL = 0
+ payloadLen = 4
+ nextHdr = 6
+ hopLimit = 7
+ v6SrcAddr = 8
+ v6DstAddr = 24
+)
+
+// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv6Fields struct {
+ // TrafficClass is the "traffic class" field of an IPv6 packet.
+ TrafficClass uint8
+
+ // FlowLabel is the "flow label" field of an IPv6 packet.
+ FlowLabel uint32
+
+ // PayloadLength is the "payload length" field of an IPv6 packet.
+ PayloadLength uint16
+
+ // NextHeader is the "next header" field of an IPv6 packet.
+ NextHeader uint8
+
+ // HopLimit is the "hop limit" field of an IPv6 packet.
+ HopLimit uint8
+
+ // SrcAddr is the "source ip address" of an IPv6 packet.
+ SrcAddr tcpip.Address
+
+ // DstAddr is the "destination ip address" of an IPv6 packet.
+ DstAddr tcpip.Address
+}
+
+// IPv6 represents an ipv6 header stored in a byte array.
+// Most of the methods of IPv6 access to the underlying slice without
+// checking the boundaries and could panic because of 'index out of range'.
+// Always call IsValid() to validate an instance of IPv6 before using other methods.
+type IPv6 []byte
+
+const (
+ // IPv6MinimumSize is the minimum size of a valid IPv6 packet.
+ IPv6MinimumSize = 40
+
+ // IPv6AddressSize is the size, in bytes, of an IPv6 address.
+ IPv6AddressSize = 16
+
+ // IPv6ProtocolNumber is IPv6's network protocol number.
+ IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
+
+ // IPv6Version is the version of the ipv6 protocol.
+ IPv6Version = 6
+
+ // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
+ // section 5.
+ IPv6MinimumMTU = 1280
+
+ // IPv6Any is the non-routable IPv6 "any" meta address.
+ IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+)
+
+// PayloadLength returns the value of the "payload length" field of the ipv6
+// header.
+func (b IPv6) PayloadLength() uint16 {
+ return binary.BigEndian.Uint16(b[payloadLen:])
+}
+
+// HopLimit returns the value of the "hop limit" field of the ipv6 header.
+func (b IPv6) HopLimit() uint8 {
+ return b[hopLimit]
+}
+
+// NextHeader returns the value of the "next header" field of the ipv6 header.
+func (b IPv6) NextHeader() uint8 {
+ return b[nextHdr]
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.NextHeader())
+}
+
+// Payload implements Network.Payload.
+func (b IPv6) Payload() []byte {
+ return b[IPv6MinimumSize:][:b.PayloadLength()]
+}
+
+// SourceAddress returns the "source address" field of the ipv6 header.
+func (b IPv6) SourceAddress() tcpip.Address {
+ return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize])
+}
+
+// DestinationAddress returns the "destination address" field of the ipv6
+// header.
+func (b IPv6) DestinationAddress() tcpip.Address {
+ return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize])
+}
+
+// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
+// checksum, it just returns 0.
+func (IPv6) Checksum() uint16 {
+ return 0
+}
+
+// TOS returns the "traffic class" and "flow label" fields of the ipv6 header.
+func (b IPv6) TOS() (uint8, uint32) {
+ v := binary.BigEndian.Uint32(b[versTCFL:])
+ return uint8(v >> 20), v & 0xfffff
+}
+
+// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header.
+func (b IPv6) SetTOS(t uint8, l uint32) {
+ vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff)
+ binary.BigEndian.PutUint32(b[versTCFL:], vtf)
+}
+
+// SetPayloadLength sets the "payload length" field of the ipv6 header.
+func (b IPv6) SetPayloadLength(payloadLength uint16) {
+ binary.BigEndian.PutUint16(b[payloadLen:], payloadLength)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv6 header.
+func (b IPv6) SetSourceAddress(addr tcpip.Address) {
+ copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr)
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv6
+// header.
+func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
+ copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
+}
+
+// SetNextHeader sets the value of the "next header" field of the ipv6 header.
+func (b IPv6) SetNextHeader(v uint8) {
+ b[nextHdr] = v
+}
+
+// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a
+// checksum, it is empty.
+func (IPv6) SetChecksum(uint16) {
+}
+
+// Encode encodes all the fields of the ipv6 header.
+func (b IPv6) Encode(i *IPv6Fields) {
+ b.SetTOS(i.TrafficClass, i.FlowLabel)
+ b.SetPayloadLength(i.PayloadLength)
+ b[nextHdr] = i.NextHeader
+ b[hopLimit] = i.HopLimit
+ copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr)
+ copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr)
+}
+
+// IsValid performs basic validation on the packet.
+func (b IPv6) IsValid(pktSize int) bool {
+ if len(b) < IPv6MinimumSize {
+ return false
+ }
+
+ dlen := int(b.PayloadLength())
+ if dlen > pktSize-IPv6MinimumSize {
+ return false
+ }
+
+ return true
+}
+
+// IsV4MappedAddress determines if the provided address is an IPv4 mapped
+// address by checking if its prefix is 0:0:0:0:0:ffff::/96.
+func IsV4MappedAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+
+ return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff")
+}
+
+// IsV6MulticastAddress determines if the provided address is an IPv6
+// multicast address (anything starting with FF).
+func IsV6MulticastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ return addr[0] == 0xff
+}
+
+// SolicitedNodeAddr computes the solicited-node multicast address. This is
+// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
+// address.
+func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
+ const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
+ return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
+}
+
+// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
+// (MAC) address.
+func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
+ // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local
+ // header, FE80::.
+ //
+ // The conversion is very nearly:
+ // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
+ // Note the capital A. The conversion aa->Aa involves a bit flip.
+ lladdrb := [16]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ 8: linkAddr[0] ^ 2,
+ 9: linkAddr[1],
+ 10: linkAddr[2],
+ 11: 0xFF,
+ 12: 0xFE,
+ 13: linkAddr[3],
+ 14: linkAddr[4],
+ 15: linkAddr[5],
+ }
+ return tcpip.Address(lladdrb[:])
+}
+
+// IsV6LinkLocalAddress determines if the provided address is an IPv6
+// link-local address (fe80::/10).
+func IsV6LinkLocalAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
+}
diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go
new file mode 100644
index 000000000..6d896355a
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_fragment.go
@@ -0,0 +1,146 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ nextHdrFrag = 0
+ fragOff = 2
+ more = 3
+ idV6 = 4
+)
+
+// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv6FragmentFields struct {
+ // NextHeader is the "next header" field of an IPv6 fragment.
+ NextHeader uint8
+
+ // FragmentOffset is the "fragment offset" field of an IPv6 fragment.
+ FragmentOffset uint16
+
+ // M is the "more" field of an IPv6 fragment.
+ M bool
+
+ // Identification is the "identification" field of an IPv6 fragment.
+ Identification uint32
+}
+
+// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
+// Most of the methods of IPv6Fragment access to the underlying slice without
+// checking the boundaries and could panic because of 'index out of range'.
+// Always call IsValid() to validate an instance of IPv6Fragment before using other methods.
+type IPv6Fragment []byte
+
+const (
+ // IPv6FragmentHeader header is the number used to specify that the next
+ // header is a fragment header, per RFC 2460.
+ IPv6FragmentHeader = 44
+
+ // IPv6FragmentHeaderSize is the size of the fragment header.
+ IPv6FragmentHeaderSize = 8
+)
+
+// Encode encodes all the fields of the ipv6 fragment.
+func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
+ b[nextHdrFrag] = i.NextHeader
+ binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3)
+ if i.M {
+ b[more] |= 1
+ }
+ binary.BigEndian.PutUint32(b[idV6:], i.Identification)
+}
+
+// IsValid performs basic validation on the fragment header.
+func (b IPv6Fragment) IsValid() bool {
+ return len(b) >= IPv6FragmentHeaderSize
+}
+
+// NextHeader returns the value of the "next header" field of the ipv6 fragment.
+func (b IPv6Fragment) NextHeader() uint8 {
+ return b[nextHdrFrag]
+}
+
+// FragmentOffset returns the "fragment offset" field of the ipv6 fragment.
+func (b IPv6Fragment) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[fragOff:]) >> 3
+}
+
+// More returns the "more" field of the ipv6 fragment.
+func (b IPv6Fragment) More() bool {
+ return b[more]&1 > 0
+}
+
+// Payload implements Network.Payload.
+func (b IPv6Fragment) Payload() []byte {
+ return b[IPv6FragmentHeaderSize:]
+}
+
+// ID returns the value of the identifier field of the ipv6 fragment.
+func (b IPv6Fragment) ID() uint32 {
+ return binary.BigEndian.Uint32(b[idV6:])
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.NextHeader())
+}
+
+// The functions below have been added only to satisfy the Network interface.
+
+// Checksum is not supported by IPv6Fragment.
+func (b IPv6Fragment) Checksum() uint16 {
+ panic("not supported")
+}
+
+// SourceAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) SourceAddress() tcpip.Address {
+ panic("not supported")
+}
+
+// DestinationAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) DestinationAddress() tcpip.Address {
+ panic("not supported")
+}
+
+// SetSourceAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetSourceAddress(tcpip.Address) {
+ panic("not supported")
+}
+
+// SetDestinationAddress is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) {
+ panic("not supported")
+}
+
+// SetChecksum is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetChecksum(uint16) {
+ panic("not supported")
+}
+
+// TOS is not supported by IPv6Fragment.
+func (b IPv6Fragment) TOS() (uint8, uint32) {
+ panic("not supported")
+}
+
+// SetTOS is not supported by IPv6Fragment.
+func (b IPv6Fragment) SetTOS(t uint8, l uint32) {
+ panic("not supported")
+}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
new file mode 100644
index 000000000..0cd89b992
--- /dev/null
+++ b/pkg/tcpip/header/tcp.go
@@ -0,0 +1,543 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "github.com/google/btree"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+// These constants are the offsets of the respective fields in the TCP header.
+const (
+ TCPSrcPortOffset = 0
+ TCPDstPortOffset = 2
+ TCPSeqNumOffset = 4
+ TCPAckNumOffset = 8
+ TCPDataOffset = 12
+ TCPFlagsOffset = 13
+ TCPWinSizeOffset = 14
+ TCPChecksumOffset = 16
+ TCPUrgentPtrOffset = 18
+)
+
+const (
+ // MaxWndScale is maximum allowed window scaling, as described in
+ // RFC 1323, section 2.3, page 11.
+ MaxWndScale = 14
+
+ // TCPMaxSACKBlocks is the maximum number of SACK blocks that can
+ // be encoded in a TCP option field.
+ TCPMaxSACKBlocks = 4
+)
+
+// Flags that may be set in a TCP segment.
+const (
+ TCPFlagFin = 1 << iota
+ TCPFlagSyn
+ TCPFlagRst
+ TCPFlagPsh
+ TCPFlagAck
+ TCPFlagUrg
+)
+
+// Options that may be present in a TCP segment.
+const (
+ TCPOptionEOL = 0
+ TCPOptionNOP = 1
+ TCPOptionMSS = 2
+ TCPOptionWS = 3
+ TCPOptionTS = 8
+ TCPOptionSACKPermitted = 4
+ TCPOptionSACK = 5
+)
+
+// TCPFields contains the fields of a TCP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type TCPFields struct {
+ // SrcPort is the "source port" field of a TCP packet.
+ SrcPort uint16
+
+ // DstPort is the "destination port" field of a TCP packet.
+ DstPort uint16
+
+ // SeqNum is the "sequence number" field of a TCP packet.
+ SeqNum uint32
+
+ // AckNum is the "acknowledgement number" field of a TCP packet.
+ AckNum uint32
+
+ // DataOffset is the "data offset" field of a TCP packet.
+ DataOffset uint8
+
+ // Flags is the "flags" field of a TCP packet.
+ Flags uint8
+
+ // WindowSize is the "window size" field of a TCP packet.
+ WindowSize uint16
+
+ // Checksum is the "checksum" field of a TCP packet.
+ Checksum uint16
+
+ // UrgentPointer is the "urgent pointer" field of a TCP packet.
+ UrgentPointer uint16
+}
+
+// TCPSynOptions is used to return the parsed TCP Options in a syn
+// segment.
+type TCPSynOptions struct {
+ // MSS is the maximum segment size provided by the peer in the SYN.
+ MSS uint16
+
+ // WS is the window scale option provided by the peer in the SYN.
+ //
+ // Set to -1 if no window scale option was provided.
+ WS int
+
+ // TS is true if the timestamp option was provided in the syn/syn-ack.
+ TS bool
+
+ // TSVal is the value of the TSVal field in the timestamp option.
+ TSVal uint32
+
+ // TSEcr is the value of the TSEcr field in the timestamp option.
+ TSEcr uint32
+
+ // SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK.
+ SACKPermitted bool
+}
+
+// SACKBlock represents a single contiguous SACK block.
+//
+// +stateify savable
+type SACKBlock struct {
+ // Start indicates the lowest sequence number in the block.
+ Start seqnum.Value
+
+ // End indicates the sequence number immediately following the last
+ // sequence number of this block.
+ End seqnum.Value
+}
+
+// Less returns true if r.Start < b.Start.
+func (r SACKBlock) Less(b btree.Item) bool {
+ return r.Start.LessThan(b.(SACKBlock).Start)
+}
+
+// Contains returns true if b is completely contained in r.
+func (r SACKBlock) Contains(b SACKBlock) bool {
+ return r.Start.LessThanEq(b.Start) && b.End.LessThanEq(r.End)
+}
+
+// TCPOptions are used to parse and cache the TCP segment options for a non
+// syn/syn-ack segment.
+//
+// +stateify savable
+type TCPOptions struct {
+ // TS is true if the TimeStamp option is enabled.
+ TS bool
+
+ // TSVal is the value in the TSVal field of the segment.
+ TSVal uint32
+
+ // TSEcr is the value in the TSEcr field of the segment.
+ TSEcr uint32
+
+ // SACKBlocks are the SACK blocks specified in the segment.
+ SACKBlocks []SACKBlock
+}
+
+// TCP represents a TCP header stored in a byte array.
+type TCP []byte
+
+const (
+ // TCPMinimumSize is the minimum size of a valid TCP packet.
+ TCPMinimumSize = 20
+
+ // TCPOptionsMaximumSize is the maximum size of TCP options.
+ TCPOptionsMaximumSize = 40
+
+ // TCPHeaderMaximumSize is the maximum header size of a TCP packet.
+ TCPHeaderMaximumSize = TCPMinimumSize + TCPOptionsMaximumSize
+
+ // TCPProtocolNumber is TCP's transport protocol number.
+ TCPProtocolNumber tcpip.TransportProtocolNumber = 6
+)
+
+// SourcePort returns the "source port" field of the tcp header.
+func (b TCP) SourcePort() uint16 {
+ return binary.BigEndian.Uint16(b[TCPSrcPortOffset:])
+}
+
+// DestinationPort returns the "destination port" field of the tcp header.
+func (b TCP) DestinationPort() uint16 {
+ return binary.BigEndian.Uint16(b[TCPDstPortOffset:])
+}
+
+// SequenceNumber returns the "sequence number" field of the tcp header.
+func (b TCP) SequenceNumber() uint32 {
+ return binary.BigEndian.Uint32(b[TCPSeqNumOffset:])
+}
+
+// AckNumber returns the "ack number" field of the tcp header.
+func (b TCP) AckNumber() uint32 {
+ return binary.BigEndian.Uint32(b[TCPAckNumOffset:])
+}
+
+// DataOffset returns the "data offset" field of the tcp header.
+func (b TCP) DataOffset() uint8 {
+ return (b[TCPDataOffset] >> 4) * 4
+}
+
+// Payload returns the data in the tcp packet.
+func (b TCP) Payload() []byte {
+ return b[b.DataOffset():]
+}
+
+// Flags returns the flags field of the tcp header.
+func (b TCP) Flags() uint8 {
+ return b[TCPFlagsOffset]
+}
+
+// WindowSize returns the "window size" field of the tcp header.
+func (b TCP) WindowSize() uint16 {
+ return binary.BigEndian.Uint16(b[TCPWinSizeOffset:])
+}
+
+// Checksum returns the "checksum" field of the tcp header.
+func (b TCP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[TCPChecksumOffset:])
+}
+
+// SetSourcePort sets the "source port" field of the tcp header.
+func (b TCP) SetSourcePort(port uint16) {
+ binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port)
+}
+
+// SetDestinationPort sets the "destination port" field of the tcp header.
+func (b TCP) SetDestinationPort(port uint16) {
+ binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port)
+}
+
+// SetChecksum sets the checksum field of the tcp header.
+func (b TCP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum)
+}
+
+// CalculateChecksum calculates the checksum of the tcp segment.
+// partialChecksum is the checksum of the network-layer pseudo-header
+// and the checksum of the segment data.
+func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
+ // Calculate the rest of the checksum.
+ return Checksum(b[:b.DataOffset()], partialChecksum)
+}
+
+// Options returns a slice that holds the unparsed TCP options in the segment.
+func (b TCP) Options() []byte {
+ return b[TCPMinimumSize:b.DataOffset()]
+}
+
+// ParsedOptions returns a TCPOptions structure which parses and caches the TCP
+// option values in the TCP segment. NOTE: Invoking this function repeatedly is
+// expensive as it reparses the options on each invocation.
+func (b TCP) ParsedOptions() TCPOptions {
+ return ParseTCPOptions(b.Options())
+}
+
+func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) {
+ binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq)
+ binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack)
+ b[TCPFlagsOffset] = flags
+ binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
+}
+
+// Encode encodes all the fields of the tcp header.
+func (b TCP) Encode(t *TCPFields) {
+ b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize)
+ binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], t.SrcPort)
+ binary.BigEndian.PutUint16(b[TCPDstPortOffset:], t.DstPort)
+ b[TCPDataOffset] = (t.DataOffset / 4) << 4
+ binary.BigEndian.PutUint16(b[TCPChecksumOffset:], t.Checksum)
+ binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], t.UrgentPointer)
+}
+
+// EncodePartial updates a subset of the fields of the tcp header. It is useful
+// in cases when similar segments are produced.
+func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) {
+ // Add the total length and "flags" field contributions to the checksum.
+ // We don't use the flags field directly from the header because it's a
+ // one-byte field with an odd offset, so it would be accounted for
+ // incorrectly by the Checksum routine.
+ tmp := make([]byte, 4)
+ binary.BigEndian.PutUint16(tmp, length)
+ binary.BigEndian.PutUint16(tmp[2:], uint16(flags))
+ checksum := Checksum(tmp, partialChecksum)
+
+ // Encode the passed-in fields.
+ b.encodeSubset(seqnum, acknum, flags, rcvwnd)
+
+ // Add the contributions of the passed-in fields to the checksum.
+ checksum = Checksum(b[TCPSeqNumOffset:TCPSeqNumOffset+8], checksum)
+ checksum = Checksum(b[TCPWinSizeOffset:TCPWinSizeOffset+2], checksum)
+
+ // Encode the checksum.
+ b.SetChecksum(^checksum)
+}
+
+// ParseSynOptions parses the options received in a SYN segment and returns the
+// relevant ones. opts should point to the option part of the TCP Header.
+func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
+ limit := len(opts)
+
+ synOpts := TCPSynOptions{
+ // Per RFC 1122, page 85: "If an MSS option is not received at
+ // connection setup, TCP MUST assume a default send MSS of 536."
+ MSS: 536,
+ // If no window scale option is specified, WS in options is
+ // returned as -1; this is because the absence of the option
+ // indicates that the we cannot use window scaling on the
+ // receive end either.
+ WS: -1,
+ }
+
+ for i := 0; i < limit; {
+ switch opts[i] {
+ case TCPOptionEOL:
+ i = limit
+ case TCPOptionNOP:
+ i++
+ case TCPOptionMSS:
+ if i+4 > limit || opts[i+1] != 4 {
+ return synOpts
+ }
+ mss := uint16(opts[i+2])<<8 | uint16(opts[i+3])
+ if mss == 0 {
+ return synOpts
+ }
+ synOpts.MSS = mss
+ i += 4
+
+ case TCPOptionWS:
+ if i+3 > limit || opts[i+1] != 3 {
+ return synOpts
+ }
+ ws := int(opts[i+2])
+ if ws > MaxWndScale {
+ ws = MaxWndScale
+ }
+ synOpts.WS = ws
+ i += 3
+
+ case TCPOptionTS:
+ if i+10 > limit || opts[i+1] != 10 {
+ return synOpts
+ }
+ synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:])
+ if isAck {
+ // If the segment is a SYN-ACK then store the Timestamp Echo Reply
+ // in the segment.
+ synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:])
+ }
+ synOpts.TS = true
+ i += 10
+ case TCPOptionSACKPermitted:
+ if i+2 > limit || opts[i+1] != 2 {
+ return synOpts
+ }
+ synOpts.SACKPermitted = true
+ i += 2
+
+ default:
+ // We don't recognize this option, just skip over it.
+ if i+2 > limit {
+ return synOpts
+ }
+ l := int(opts[i+1])
+ // If the length is incorrect or if l+i overflows the
+ // total options length then return false.
+ if l < 2 || i+l > limit {
+ return synOpts
+ }
+ i += l
+ }
+ }
+
+ return synOpts
+}
+
+// ParseTCPOptions extracts and stores all known options in the provided byte
+// slice in a TCPOptions structure.
+func ParseTCPOptions(b []byte) TCPOptions {
+ opts := TCPOptions{}
+ limit := len(b)
+ for i := 0; i < limit; {
+ switch b[i] {
+ case TCPOptionEOL:
+ i = limit
+ case TCPOptionNOP:
+ i++
+ case TCPOptionTS:
+ if i+10 > limit || (b[i+1] != 10) {
+ return opts
+ }
+ opts.TS = true
+ opts.TSVal = binary.BigEndian.Uint32(b[i+2:])
+ opts.TSEcr = binary.BigEndian.Uint32(b[i+6:])
+ i += 10
+ case TCPOptionSACK:
+ if i+2 > limit {
+ // Malformed SACK block, just return and stop parsing.
+ return opts
+ }
+ sackOptionLen := int(b[i+1])
+ if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
+ // Malformed SACK block, just return and stop parsing.
+ return opts
+ }
+ numBlocks := (sackOptionLen - 2) / 8
+ opts.SACKBlocks = []SACKBlock{}
+ for j := 0; j < numBlocks; j++ {
+ start := binary.BigEndian.Uint32(b[i+2+j*8:])
+ end := binary.BigEndian.Uint32(b[i+2+j*8+4:])
+ opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{
+ Start: seqnum.Value(start),
+ End: seqnum.Value(end),
+ })
+ }
+ i += sackOptionLen
+ default:
+ // We don't recognize this option, just skip over it.
+ if i+2 > limit {
+ return opts
+ }
+ l := int(b[i+1])
+ // If the length is incorrect or if l+i overflows the
+ // total options length then return false.
+ if l < 2 || i+l > limit {
+ return opts
+ }
+ i += l
+ }
+ }
+ return opts
+}
+
+// EncodeMSSOption encodes the MSS TCP option with the provided MSS values in
+// the supplied buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer.
+func EncodeMSSOption(mss uint32, b []byte) int {
+ // mssOptionSize is the number of bytes in a valid MSS option.
+ const mssOptionSize = 4
+
+ if len(b) < mssOptionSize {
+ return 0
+ }
+ b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss)
+ return mssOptionSize
+}
+
+// EncodeWSOption encodes the WS TCP option with the WS value in the
+// provided buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer.
+func EncodeWSOption(ws int, b []byte) int {
+ if len(b) < 3 {
+ return 0
+ }
+ b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws)
+ return int(b[1])
+}
+
+// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp
+// option into the provided buffer. If the buffer is smaller than expected it
+// just returns without encoding anything. It returns the number of bytes
+// written to the provided buffer.
+func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
+ if len(b) < 10 {
+ return 0
+ }
+ b[0], b[1] = TCPOptionTS, 10
+ binary.BigEndian.PutUint32(b[2:], tsVal)
+ binary.BigEndian.PutUint32(b[6:], tsEcr)
+ return int(b[1])
+}
+
+// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided
+// buffer. If the buffer is smaller than required it just returns without
+// encoding anything. It returns the number of bytes written to the provided
+// buffer.
+func EncodeSACKPermittedOption(b []byte) int {
+ if len(b) < 2 {
+ return 0
+ }
+
+ b[0], b[1] = TCPOptionSACKPermitted, 2
+ return int(b[1])
+}
+
+// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block
+// in the provided slice. It tries to fit in as many blocks as possible based on
+// number of bytes available in the provided buffer. It returns the number of
+// bytes written to the provided buffer.
+func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int {
+ if len(sackBlocks) == 0 {
+ return 0
+ }
+ l := len(sackBlocks)
+ if l > TCPMaxSACKBlocks {
+ l = TCPMaxSACKBlocks
+ }
+ if ll := (len(b) - 2) / 8; ll < l {
+ l = ll
+ }
+ if l == 0 {
+ // There is not enough space in the provided buffer to add
+ // any SACK blocks.
+ return 0
+ }
+ b[0] = TCPOptionSACK
+ b[1] = byte(l*8 + 2)
+ for i := 0; i < l; i++ {
+ binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start))
+ binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End))
+ }
+ return int(b[1])
+}
+
+// EncodeNOP adds an explicit NOP to the option list.
+func EncodeNOP(b []byte) int {
+ if len(b) == 0 {
+ return 0
+ }
+ b[0] = TCPOptionNOP
+ return 1
+}
+
+// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align
+// the option buffer. It adds padding bytes after the offset specified and
+// returns the number of padding bytes added. The passed in options slice
+// must have space for the padding bytes.
+func AddTCPOptionPadding(options []byte, offset int) int {
+ paddingToAdd := -offset & 3
+ // Now add any padding bytes that might be required to quad align the
+ // options.
+ for i := offset; i < offset+paddingToAdd; i++ {
+ options[i] = TCPOptionNOP
+ }
+ return paddingToAdd
+}
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
new file mode 100644
index 000000000..2205fec18
--- /dev/null
+++ b/pkg/tcpip/header/udp.go
@@ -0,0 +1,110 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ udpSrcPort = 0
+ udpDstPort = 2
+ udpLength = 4
+ udpChecksum = 6
+)
+
+// UDPFields contains the fields of a UDP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type UDPFields struct {
+ // SrcPort is the "source port" field of a UDP packet.
+ SrcPort uint16
+
+ // DstPort is the "destination port" field of a UDP packet.
+ DstPort uint16
+
+ // Length is the "length" field of a UDP packet.
+ Length uint16
+
+ // Checksum is the "checksum" field of a UDP packet.
+ Checksum uint16
+}
+
+// UDP represents a UDP header stored in a byte array.
+type UDP []byte
+
+const (
+ // UDPMinimumSize is the minimum size of a valid UDP packet.
+ UDPMinimumSize = 8
+
+ // UDPProtocolNumber is UDP's transport protocol number.
+ UDPProtocolNumber tcpip.TransportProtocolNumber = 17
+)
+
+// SourcePort returns the "source port" field of the udp header.
+func (b UDP) SourcePort() uint16 {
+ return binary.BigEndian.Uint16(b[udpSrcPort:])
+}
+
+// DestinationPort returns the "destination port" field of the udp header.
+func (b UDP) DestinationPort() uint16 {
+ return binary.BigEndian.Uint16(b[udpDstPort:])
+}
+
+// Length returns the "length" field of the udp header.
+func (b UDP) Length() uint16 {
+ return binary.BigEndian.Uint16(b[udpLength:])
+}
+
+// Payload returns the data contained in the UDP datagram.
+func (b UDP) Payload() []byte {
+ return b[UDPMinimumSize:]
+}
+
+// Checksum returns the "checksum" field of the udp header.
+func (b UDP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[udpChecksum:])
+}
+
+// SetSourcePort sets the "source port" field of the udp header.
+func (b UDP) SetSourcePort(port uint16) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], port)
+}
+
+// SetDestinationPort sets the "destination port" field of the udp header.
+func (b UDP) SetDestinationPort(port uint16) {
+ binary.BigEndian.PutUint16(b[udpDstPort:], port)
+}
+
+// SetChecksum sets the "checksum" field of the udp header.
+func (b UDP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
+}
+
+// CalculateChecksum calculates the checksum of the udp packet, given the
+// checksum of the network-layer pseudo-header and the checksum of the payload.
+func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
+ // Calculate the rest of the checksum.
+ return Checksum(b[:UDPMinimumSize], partialChecksum)
+}
+
+// Encode encodes all the fields of the udp header.
+func (b UDP) Encode(u *UDPFields) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
+ binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
+ binary.BigEndian.PutUint16(b[udpLength:], u.Length)
+ binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
new file mode 100644
index 000000000..1f889c2a0
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -0,0 +1,372 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package fdbased provides the implemention of data-link layer endpoints
+// backed by boundary-preserving file descriptors (e.g., TUN devices,
+// seqpacket/datagram sockets).
+//
+// FD based endpoints can be used in the networking stack by calling New() to
+// create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+package fdbased
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// linkDispatcher reads packets from the link FD and dispatches them to the
+// NetworkDispatcher.
+type linkDispatcher interface {
+ dispatch() (bool, *tcpip.Error)
+}
+
+// PacketDispatchMode are the various supported methods of receiving and
+// dispatching packets from the underlying FD.
+type PacketDispatchMode int
+
+const (
+ // Readv is the default dispatch mode and is the least performant of the
+ // dispatch options but the one that is supported by all underlying FD
+ // types.
+ Readv PacketDispatchMode = iota
+ // RecvMMsg enables use of recvmmsg() syscall instead of readv() to
+ // read inbound packets. This reduces # of syscalls needed to process
+ // packets.
+ //
+ // NOTE: recvmmsg() is only supported for sockets, so if the underlying
+ // FD is not a socket then the code will still fall back to the readv()
+ // path.
+ RecvMMsg
+ // PacketMMap enables use of PACKET_RX_RING to receive packets from the
+ // NIC. PacketMMap requires that the underlying FD be an AF_PACKET. The
+ // primary use-case for this is runsc which uses an AF_PACKET FD to
+ // receive packets from the veth device.
+ PacketMMap
+)
+
+type endpoint struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // mtu (maximum transmission unit) is the maximum size of a packet.
+ mtu uint32
+
+ // hdrSize specifies the link-layer header size. If set to 0, no header
+ // is added/removed; otherwise an ethernet header is used.
+ hdrSize int
+
+ // addr is the address of the endpoint.
+ addr tcpip.LinkAddress
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // closed is a function to be called when the FD's peer (if any) closes
+ // its end of the communication pipe.
+ closed func(*tcpip.Error)
+
+ inboundDispatcher linkDispatcher
+ dispatcher stack.NetworkDispatcher
+
+ // packetDispatchMode controls the packet dispatcher used by this
+ // endpoint.
+ packetDispatchMode PacketDispatchMode
+
+ // gsoMaxSize is the maximum GSO packet size. It is zero if GSO is
+ // disabled.
+ gsoMaxSize uint32
+}
+
+// Options specify the details about the fd-based endpoint to be created.
+type Options struct {
+ FD int
+ MTU uint32
+ EthernetHeader bool
+ ClosedFunc func(*tcpip.Error)
+ Address tcpip.LinkAddress
+ SaveRestore bool
+ DisconnectOk bool
+ GSOMaxSize uint32
+ PacketDispatchMode PacketDispatchMode
+ TXChecksumOffload bool
+ RXChecksumOffload bool
+}
+
+// New creates a new fd-based endpoint.
+//
+// Makes fd non-blocking, but does not take ownership of fd, which must remain
+// open for the lifetime of the returned endpoint.
+func New(opts *Options) (tcpip.LinkEndpointID, error) {
+ if err := syscall.SetNonblock(opts.FD, true); err != nil {
+ return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", opts.FD, err)
+ }
+
+ caps := stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ hdrSize := 0
+ if opts.EthernetHeader {
+ hdrSize = header.EthernetMinimumSize
+ caps |= stack.CapabilityResolutionRequired
+ }
+
+ if opts.SaveRestore {
+ caps |= stack.CapabilitySaveRestore
+ }
+
+ if opts.DisconnectOk {
+ caps |= stack.CapabilityDisconnectOk
+ }
+
+ e := &endpoint{
+ fd: opts.FD,
+ mtu: opts.MTU,
+ caps: caps,
+ closed: opts.ClosedFunc,
+ addr: opts.Address,
+ hdrSize: hdrSize,
+ packetDispatchMode: opts.PacketDispatchMode,
+ }
+
+ isSocket, err := isSocketFD(e.fd)
+ if err != nil {
+ return 0, err
+ }
+ if isSocket {
+ if opts.GSOMaxSize != 0 {
+ e.caps |= stack.CapabilityGSO
+ e.gsoMaxSize = opts.GSOMaxSize
+ }
+ }
+ e.inboundDispatcher, err = createInboundDispatcher(e, isSocket)
+ if err != nil {
+ return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+ }
+
+ return stack.RegisterLinkEndpoint(e), nil
+}
+
+func createInboundDispatcher(e *endpoint, isSocket bool) (linkDispatcher, error) {
+ // By default use the readv() dispatcher as it works with all kinds of
+ // FDs (tap/tun/unix domain sockets and af_packet).
+ inboundDispatcher, err := newReadVDispatcher(e.fd, e)
+ if err != nil {
+ return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", e.fd, e, err)
+ }
+
+ if isSocket {
+ switch e.packetDispatchMode {
+ case PacketMMap:
+ inboundDispatcher, err = newPacketMMapDispatcher(e.fd, e)
+ if err != nil {
+ return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", e.fd, e, err)
+ }
+ case RecvMMsg:
+ // If the provided FD is a socket then we optimize
+ // packet reads by using recvmmsg() instead of read() to
+ // read packets in a batch.
+ inboundDispatcher, err = newRecvMMsgDispatcher(e.fd, e)
+ if err != nil {
+ return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", e.fd, e, err)
+ }
+ }
+ }
+ return inboundDispatcher, nil
+}
+
+func isSocketFD(fd int) (bool, error) {
+ var stat syscall.Stat_t
+ if err := syscall.Fstat(fd, &stat); err != nil {
+ return false, fmt.Errorf("syscall.Fstat(%v,...) failed: %v", fd, err)
+ }
+ return (stat.Mode & syscall.S_IFSOCK) == syscall.S_IFSOCK, nil
+}
+
+// Attach launches the goroutine that reads packets from the file descriptor and
+// dispatches them via the provided dispatcher.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ // Link endpoints are not savable. When transportation endpoints are
+ // saved, they stop sending outgoing packets and all incoming packets
+ // are rejected.
+ go e.dispatchLoop() // S/R-SAFE: See above.
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *endpoint) MTU() uint32 {
+ return e.mtu
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength returns the maximum size of the link-layer header.
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// virtioNetHdr is declared in linux/virtio_net.h.
+type virtioNetHdr struct {
+ flags uint8
+ gsoType uint8
+ hdrLen uint16
+ gsoSize uint16
+ csumStart uint16
+ csumOffset uint16
+}
+
+// These constants are declared in linux/virtio_net.h.
+const (
+ _VIRTIO_NET_HDR_F_NEEDS_CSUM = 1
+
+ _VIRTIO_NET_HDR_GSO_TCPV4 = 1
+ _VIRTIO_NET_HDR_GSO_TCPV6 = 4
+)
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
+ ethHdr := &header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ Type: protocol,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+ }
+
+ if e.Capabilities()&stack.CapabilityGSO != 0 {
+ vnetHdr := virtioNetHdr{}
+ vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
+ if gso != nil {
+ vnetHdr.hdrLen = uint16(hdr.UsedLength())
+ if gso.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
+ vnetHdr.csumOffset = gso.CsumOffset
+ }
+ if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS {
+ switch gso.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", gso.Type))
+ }
+ vnetHdr.gsoSize = gso.MSS
+ }
+ }
+
+ return rawfile.NonBlockingWrite3(e.fd, vnetHdrBuf, hdr.View(), payload.ToView())
+ }
+
+ if payload.Size() == 0 {
+ return rawfile.NonBlockingWrite(e.fd, hdr.View())
+ }
+
+ return rawfile.NonBlockingWrite3(e.fd, hdr.View(), payload.ToView(), nil)
+}
+
+// WriteRawPacket writes a raw packet directly to the file descriptor.
+func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
+ return rawfile.NonBlockingWrite(e.fd, packet)
+}
+
+// dispatchLoop reads packets from the file descriptor in a loop and dispatches
+// them to the network stack.
+func (e *endpoint) dispatchLoop() *tcpip.Error {
+ for {
+ cont, err := e.inboundDispatcher.dispatch()
+ if err != nil || !cont {
+ if e.closed != nil {
+ e.closed(err)
+ }
+ return err
+ }
+ }
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ return e.gsoMaxSize
+}
+
+// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes
+// to the FD, but does not read from it. All reads come from injected packets.
+type InjectableEndpoint struct {
+ endpoint
+
+ dispatcher stack.NetworkDispatcher
+}
+
+// Attach saves the stack network-layer dispatcher for use later when packets
+// are injected.
+func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// Inject injects an inbound packet.
+func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+}
+
+// NewInjectable creates a new fd-based InjectableEndpoint.
+func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) {
+ syscall.SetNonblock(fd, true)
+
+ e := &InjectableEndpoint{endpoint: endpoint{
+ fd: fd,
+ mtu: mtu,
+ caps: capabilities,
+ }}
+
+ return stack.RegisterLinkEndpoint(e), e
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
new file mode 100644
index 000000000..97a477b61
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
@@ -0,0 +1,32 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package fdbased
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{}))
+
+func vnetHdrToByteSlice(hdr *virtioNetHdr) (slice []byte) {
+ sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
+ sh.Data = uintptr(unsafe.Pointer(hdr))
+ sh.Len = virtioNetHdrSize
+ sh.Cap = virtioNetHdrSize
+ return
+}
diff --git a/pkg/tcpip/link/fdbased/fdbased_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go
new file mode 100755
index 000000000..0555db528
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package fdbased
+
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
new file mode 100644
index 000000000..6b7f2a185
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -0,0 +1,25 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !linux !amd64
+
+package fdbased
+
+import "gvisor.googlesource.com/gvisor/pkg/tcpip"
+
+// Stubbed out version for non-linux/non-amd64 platforms.
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, *tcpip.Error) {
+ return nil, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64.go b/pkg/tcpip/link/fdbased/mmap_amd64.go
new file mode 100644
index 000000000..1c2d8c468
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap_amd64.go
@@ -0,0 +1,194 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64
+
+package fdbased
+
+import (
+ "encoding/binary"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+)
+
+const (
+ tPacketAlignment = uintptr(16)
+ tpStatusKernel = 0
+ tpStatusUser = 1
+ tpStatusCopy = 2
+ tpStatusLosing = 4
+)
+
+// We overallocate the frame size to accommodate space for the
+// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
+//
+// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
+//
+// NOTE:
+// Frames need to be aligned at 16 byte boundaries.
+// BlockSize needs to be page aligned.
+//
+// For details see PACKET_MMAP setting constraints in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+const (
+ tpFrameSize = 65536 + 128
+ tpBlockSize = tpFrameSize * 32
+ tpBlockNR = 1
+ tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
+)
+
+// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
+// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
+func tPacketAlign(v uintptr) uintptr {
+ return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
+}
+
+// tPacketReq is the tpacket_req structure as described in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+type tPacketReq struct {
+ tpBlockSize uint32
+ tpBlockNR uint32
+ tpFrameSize uint32
+ tpFrameNR uint32
+}
+
+// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
+type tPacketHdr []byte
+
+const (
+ tpStatusOffset = 0
+ tpLenOffset = 8
+ tpSnapLenOffset = 12
+ tpMacOffset = 16
+ tpNetOffset = 18
+ tpSecOffset = 20
+ tpUSecOffset = 24
+)
+
+func (t tPacketHdr) tpLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpLenOffset:])
+}
+
+func (t tPacketHdr) tpSnapLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
+}
+
+func (t tPacketHdr) tpMac() uint16 {
+ return binary.LittleEndian.Uint16(t[tpMacOffset:])
+}
+
+func (t tPacketHdr) tpNet() uint16 {
+ return binary.LittleEndian.Uint16(t[tpNetOffset:])
+}
+
+func (t tPacketHdr) tpSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSecOffset:])
+}
+
+func (t tPacketHdr) tpUSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpUSecOffset:])
+}
+
+func (t tPacketHdr) Payload() []byte {
+ return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
+}
+
+// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
+// See: mmap_amd64_unsafe.go for implementation details.
+type packetMMapDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // ringBuffer is only used when PacketMMap dispatcher is used and points
+ // to the start of the mmapped PACKET_RX_RING buffer.
+ ringBuffer []byte
+
+ // ringOffset is the current offset into the ring buffer where the next
+ // inbound packet will be placed by the kernel.
+ ringOffset int
+}
+
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+ hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ for hdr.tpStatus()&tpStatusUser == 0 {
+ event := rawfile.PollEvent{
+ FD: int32(d.fd),
+ Events: unix.POLLIN | unix.POLLERR,
+ }
+ if _, errno := rawfile.BlockingPoll(&event, 1, -1); errno != 0 {
+ if errno == syscall.EINTR {
+ continue
+ }
+ return nil, rawfile.TranslateErrno(errno)
+ }
+ if hdr.tpStatus()&tpStatusCopy != 0 {
+ // This frame is truncated so skip it after flipping the
+ // buffer to the kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ continue
+ }
+ }
+
+ // Copy out the packet from the mmapped frame to a locally owned buffer.
+ pkt := make([]byte, hdr.tpSnapLen())
+ copy(pkt, hdr.Payload())
+ // Release packet to kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ return pkt, nil
+}
+
+// dispatch reads packets from an mmaped ring buffer and dispatches them to the
+// network stack.
+func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+ pkt, err := d.readMMappedPacket()
+ if err != nil {
+ return false, err
+ }
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(pkt)
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(pkt) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ pkt = pkt[d.e.hdrSize:]
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ return true, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go b/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go
new file mode 100644
index 000000000..47cb1d1cc
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go
@@ -0,0 +1,84 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64
+
+package fdbased
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// tPacketHdrlen is the TPACKET_HDRLEN variable defined in <linux/if_packet.h>.
+var tPacketHdrlen = tPacketAlign(unsafe.Sizeof(tPacketHdr{}) + unsafe.Sizeof(syscall.RawSockaddrLinklayer{}))
+
+// tpStatus returns the frame status field.
+// The status is concurrently updated by the kernel as a result we must
+// use atomic operations to prevent races.
+func (t tPacketHdr) tpStatus() uint32 {
+ hdr := unsafe.Pointer(&t[0])
+ statusPtr := unsafe.Pointer(uintptr(hdr) + uintptr(tpStatusOffset))
+ return atomic.LoadUint32((*uint32)(statusPtr))
+}
+
+// setTPStatus set's the frame status to the provided status.
+// The status is concurrently updated by the kernel as a result we must
+// use atomic operations to prevent races.
+func (t tPacketHdr) setTPStatus(status uint32) {
+ hdr := unsafe.Pointer(&t[0])
+ statusPtr := unsafe.Pointer(uintptr(hdr) + uintptr(tpStatusOffset))
+ atomic.StoreUint32((*uint32)(statusPtr), status)
+}
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &packetMMapDispatcher{
+ fd: fd,
+ e: e,
+ }
+ pageSize := unix.Getpagesize()
+ if tpBlockSize%pageSize != 0 {
+ return nil, fmt.Errorf("tpBlockSize: %d is not page aligned, pagesize: %d", tpBlockSize, pageSize)
+ }
+ tReq := tPacketReq{
+ tpBlockSize: uint32(tpBlockSize),
+ tpBlockNR: uint32(tpBlockNR),
+ tpFrameSize: uint32(tpFrameSize),
+ tpFrameNR: uint32(tpFrameNR),
+ }
+ // Setup PACKET_RX_RING.
+ if err := setsockopt(d.fd, syscall.SOL_PACKET, syscall.PACKET_RX_RING, unsafe.Pointer(&tReq), unsafe.Sizeof(tReq)); err != nil {
+ return nil, fmt.Errorf("failed to enable PACKET_RX_RING: %v", err)
+ }
+ // Let's mmap the blocks.
+ sz := tpBlockSize * tpBlockNR
+ buf, err := syscall.Mmap(d.fd, 0, sz, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
+ if err != nil {
+ return nil, fmt.Errorf("syscall.Mmap(...,0, %v, ...) failed = %v", sz, err)
+ }
+ d.ringBuffer = buf
+ return d, nil
+}
+
+func setsockopt(fd, level, name int, val unsafe.Pointer, vallen uintptr) error {
+ if _, _, errno := syscall.Syscall6(syscall.SYS_SETSOCKOPT, uintptr(fd), uintptr(level), uintptr(name), uintptr(val), vallen, 0); errno != 0 {
+ return error(errno)
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
new file mode 100644
index 000000000..1ae0e3359
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -0,0 +1,309 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package fdbased
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// BufConfig defines the shape of the vectorised view used to read packets from the NIC.
+var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768}
+
+// readVDispatcher uses readv() system call to read inbound packets and
+// dispatches them.
+type readVDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // views are the actual buffers that hold the packet contents.
+ views []buffer.View
+
+ // iovecs are initialized with base pointers/len of the corresponding
+ // entries in the views defined above, except when GSO is enabled then
+ // the first iovec points to a buffer for the vnet header which is
+ // stripped before the views are passed up the stack for further
+ // processing.
+ iovecs []syscall.Iovec
+}
+
+func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &readVDispatcher{fd: fd, e: e}
+ d.views = make([]buffer.View, len(BufConfig))
+ iovLen := len(BufConfig)
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ iovLen++
+ }
+ d.iovecs = make([]syscall.Iovec, iovLen)
+ return d, nil
+}
+
+func (d *readVDispatcher) allocateViews(bufConfig []int) {
+ var vnetHdr [virtioNetHdrSize]byte
+ vnetHdrOff := 0
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ // The kernel adds virtioNetHdr before each packet, but
+ // we don't use it, so so we allocate a buffer for it,
+ // add it in iovecs but don't add it in a view.
+ d.iovecs[0] = syscall.Iovec{
+ Base: &vnetHdr[0],
+ Len: uint64(virtioNetHdrSize),
+ }
+ vnetHdrOff++
+ }
+ for i := 0; i < len(bufConfig); i++ {
+ if d.views[i] != nil {
+ break
+ }
+ b := buffer.NewView(bufConfig[i])
+ d.views[i] = b
+ d.iovecs[i+vnetHdrOff] = syscall.Iovec{
+ Base: &b[0],
+ Len: uint64(len(b)),
+ }
+ }
+}
+
+func (d *readVDispatcher) capViews(n int, buffers []int) int {
+ c := 0
+ for i, s := range buffers {
+ c += s
+ if c >= n {
+ d.views[i].CapLength(s - (c - n))
+ return i + 1
+ }
+ }
+ return len(buffers)
+}
+
+// dispatch reads one packet from the file descriptor and dispatches it.
+func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
+ d.allocateViews(BufConfig)
+
+ n, err := rawfile.BlockingReadv(d.fd, d.iovecs)
+ if err != nil {
+ return false, err
+ }
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ // Skip virtioNetHdr which is added before each packet, it
+ // isn't used and it isn't in a view.
+ n -= virtioNetHdrSize
+ }
+ if n <= d.e.hdrSize {
+ return false, nil
+ }
+
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(d.views[0])
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(d.views[0]) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ used := d.capViews(n, BufConfig)
+ vv := buffer.NewVectorisedView(n, d.views[:used])
+ vv.TrimFront(d.e.hdrSize)
+
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv)
+
+ // Prepare e.views for another packet: release used views.
+ for i := 0; i < used; i++ {
+ d.views[i] = nil
+ }
+
+ return true, nil
+}
+
+// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and
+// dispatches them.
+type recvMMsgDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // views is an array of array of buffers that contain packet contents.
+ views [][]buffer.View
+
+ // iovecs is an array of array of iovec records where each iovec base
+ // pointer and length are initialzed to the corresponding view above,
+ // except when GSO is neabled then the first iovec in each array of
+ // iovecs points to a buffer for the vnet header which is stripped
+ // before the views are passed up the stack for further processing.
+ iovecs [][]syscall.Iovec
+
+ // msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to
+ // reference an array of iovecs in the iovecs field defined above. This
+ // array is passed as the parameter to recvmmsg call to retrieve
+ // potentially more than 1 packet per syscall.
+ msgHdrs []rawfile.MMsgHdr
+}
+
+const (
+ // MaxMsgsPerRecv is the maximum number of packets we want to retrieve
+ // in a single RecvMMsg call.
+ MaxMsgsPerRecv = 8
+)
+
+func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &recvMMsgDispatcher{
+ fd: fd,
+ e: e,
+ }
+ d.views = make([][]buffer.View, MaxMsgsPerRecv)
+ for i := range d.views {
+ d.views[i] = make([]buffer.View, len(BufConfig))
+ }
+ d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv)
+ iovLen := len(BufConfig)
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ // virtioNetHdr is prepended before each packet.
+ iovLen++
+ }
+ for i := range d.iovecs {
+ d.iovecs[i] = make([]syscall.Iovec, iovLen)
+ }
+ d.msgHdrs = make([]rawfile.MMsgHdr, MaxMsgsPerRecv)
+ for i := range d.msgHdrs {
+ d.msgHdrs[i].Msg.Iov = &d.iovecs[i][0]
+ d.msgHdrs[i].Msg.Iovlen = uint64(iovLen)
+ }
+ return d, nil
+}
+
+func (d *recvMMsgDispatcher) capViews(k, n int, buffers []int) int {
+ c := 0
+ for i, s := range buffers {
+ c += s
+ if c >= n {
+ d.views[k][i].CapLength(s - (c - n))
+ return i + 1
+ }
+ }
+ return len(buffers)
+}
+
+func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) {
+ for k := 0; k < len(d.views); k++ {
+ var vnetHdr [virtioNetHdrSize]byte
+ vnetHdrOff := 0
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ // The kernel adds virtioNetHdr before each packet, but
+ // we don't use it, so so we allocate a buffer for it,
+ // add it in iovecs but don't add it in a view.
+ d.iovecs[k][0] = syscall.Iovec{
+ Base: &vnetHdr[0],
+ Len: uint64(virtioNetHdrSize),
+ }
+ vnetHdrOff++
+ }
+ for i := 0; i < len(bufConfig); i++ {
+ if d.views[k][i] != nil {
+ break
+ }
+ b := buffer.NewView(bufConfig[i])
+ d.views[k][i] = b
+ d.iovecs[k][i+vnetHdrOff] = syscall.Iovec{
+ Base: &b[0],
+ Len: uint64(len(b)),
+ }
+ }
+ }
+}
+
+// recvMMsgDispatch reads more than one packet at a time from the file
+// descriptor and dispatches it.
+func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
+ d.allocateViews(BufConfig)
+
+ nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs)
+ if err != nil {
+ return false, err
+ }
+ // Process each of received packets.
+ for k := 0; k < nMsgs; k++ {
+ n := int(d.msgHdrs[k].Len)
+ if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ n -= virtioNetHdrSize
+ }
+ if n <= d.e.hdrSize {
+ return false, nil
+ }
+
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(d.views[k][0])
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(d.views[k][0]) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ used := d.capViews(k, int(n), BufConfig)
+ vv := buffer.NewVectorisedView(int(n), d.views[k][:used])
+ vv.TrimFront(d.e.hdrSize)
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv)
+
+ // Prepare e.views for another packet: release used views.
+ for i := 0; i < used; i++ {
+ d.views[k][i] = nil
+ }
+ }
+
+ for k := 0; k < nMsgs; k++ {
+ d.msgHdrs[k].Len = 0
+ }
+
+ return true, nil
+}
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
new file mode 100644
index 000000000..2c1148123
--- /dev/null
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -0,0 +1,87 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package loopback provides the implemention of loopback data-link layer
+// endpoints. Such endpoints just turn outbound packets into inbound ones.
+//
+// Loopback endpoints can be used in the networking stack by calling New() to
+// create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+package loopback
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ dispatcher stack.NetworkDispatcher
+}
+
+// New creates a new loopback endpoint. This link-layer endpoint just turns
+// outbound packets into inbound packets.
+func New() tcpip.LinkEndpointID {
+ return stack.RegisterLinkEndpoint(&endpoint{})
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network-
+// layer dispatcher for later use when packets need to be dispatched.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns a constant that matches the
+// linux loopback interface.
+func (*endpoint) MTU() uint32 {
+ return 65536
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. Loopback advertises
+// itself as supporting checksum offload, but in reality it's just omitted.
+func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return stack.CapabilityRXChecksumOffload | stack.CapabilityTXChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityLoopback
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. Given that the
+// loopback interface doesn't have a header, it just returns 0.
+func (*endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*endpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
+// packets to the network-layer dispatcher.
+func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+
+ // Because we're immediately turning around and writing the packet back to the
+ // rx path, we intentionally don't preserve the remote and local link
+ // addresses from the stack.Route we're passed.
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+
+ return nil
+}
diff --git a/pkg/tcpip/link/loopback/loopback_state_autogen.go b/pkg/tcpip/link/loopback/loopback_state_autogen.go
new file mode 100755
index 000000000..87ec8cfc7
--- /dev/null
+++ b/pkg/tcpip/link/loopback/loopback_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package loopback
+
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
new file mode 100644
index 000000000..b54131573
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// BlockingPoll makes the poll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (n int, err syscall.Errno)
+TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
+ CALL ·callEntersyscallblock(SB)
+ MOVQ fds+0(FP), DI
+ MOVQ nfds+8(FP), SI
+ MOVQ timeout+16(FP), DX
+ MOVQ $0x7, AX // SYS_POLL
+ SYSCALL
+ CMPQ AX, $0xfffffffffffff001
+ JLS ok
+ MOVQ $-1, n+24(FP)
+ NEGQ AX
+ MOVQ AX, err+32(FP)
+ CALL ·callExitsyscall(SB)
+ RET
+ok:
+ MOVQ AX, n+24(FP)
+ MOVQ $0, err+32(FP)
+ CALL ·callExitsyscall(SB)
+ RET
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
new file mode 100644
index 000000000..c87268610
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64
+// +build go1.12
+// +build !go1.14
+
+// Check go:linkname function signatures when updating Go version.
+
+package rawfile
+
+import (
+ "syscall"
+ _ "unsafe" // for go:linkname
+)
+
+//go:noescape
+func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno)
+
+// Use go:linkname to call into the runtime. As of Go 1.12 this has to
+// be done from Go code so that we make an ABIInternal call to an
+// ABIInternal function; see https://golang.org/issue/27539.
+
+// We need to call both entersyscallblock and exitsyscall this way so
+// that the runtime's check on the stack pointer lines up.
+
+// Note that calling an unexported function in the runtime package is
+// unsafe and this hack is likely to break in future Go releases.
+
+//go:linkname entersyscallblock runtime.entersyscallblock
+func entersyscallblock()
+
+//go:linkname exitsyscall runtime.exitsyscall
+func exitsyscall()
+
+// These forwarding functions must be nosplit because 1) we must
+// disallow preemption between entersyscallblock and exitsyscall, and
+// 2) we have an untyped assembly frame on the stack which can not be
+// grown or moved.
+
+//go:nosplit
+func callEntersyscallblock() {
+ entersyscallblock()
+}
+
+//go:nosplit
+func callExitsyscall() {
+ exitsyscall()
+}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
new file mode 100644
index 000000000..4eab77c74
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,!amd64
+
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// BlockingPoll is just a stub function that forwards to the poll() system call
+// on non-amd64 platforms.
+func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno) {
+ n, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(fds)), uintptr(nfds), uintptr(timeout))
+ return int(n), e
+}
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
new file mode 100644
index 000000000..8bde41637
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -0,0 +1,70 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package rawfile
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const maxErrno = 134
+
+var translations [maxErrno]*tcpip.Error
+
+// TranslateErrno translate an errno from the syscall package into a
+// *tcpip.Error.
+//
+// Valid, but unreconigized errnos will be translated to
+// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos.
+func TranslateErrno(e syscall.Errno) *tcpip.Error {
+ if err := translations[e]; err != nil {
+ return err
+ }
+ return tcpip.ErrInvalidEndpointState
+}
+
+func addTranslation(host syscall.Errno, trans *tcpip.Error) {
+ if translations[host] != nil {
+ panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host))
+ }
+ translations[host] = trans
+}
+
+func init() {
+ addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress)
+ addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute)
+ addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState)
+ addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting)
+ addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected)
+ addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse)
+ addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress)
+ addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend)
+ addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock)
+ addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused)
+ addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout)
+ addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted)
+ addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired)
+ addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported)
+ addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported)
+ addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected)
+ addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset)
+ addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted)
+ addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong)
+ addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace)
+}
diff --git a/pkg/tcpip/link/rawfile/rawfile_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go
new file mode 100755
index 000000000..662c04444
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package rawfile
+
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
new file mode 100644
index 000000000..86db7a487
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -0,0 +1,182 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package rawfile contains utilities for using the netstack with raw host
+// files on Linux hosts.
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// GetMTU determines the MTU of a network interface device.
+func GetMTU(name string) (uint32, error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ return 0, err
+ }
+
+ defer syscall.Close(fd)
+
+ var ifreq struct {
+ name [16]byte
+ mtu int32
+ _ [20]byte
+ }
+
+ copy(ifreq.name[:], name)
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq)))
+ if errno != 0 {
+ return 0, errno
+ }
+
+ return uint32(ifreq.mtu), nil
+}
+
+// NonBlockingWrite writes the given buffer to a file descriptor. It fails if
+// partial data is written.
+func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
+ var ptr unsafe.Pointer
+ if len(buf) > 0 {
+ ptr = unsafe.Pointer(&buf[0])
+ }
+
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITE, uintptr(fd), uintptr(ptr), uintptr(len(buf)))
+ if e != 0 {
+ return TranslateErrno(e)
+ }
+
+ return nil
+}
+
+// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a
+// single syscall. It fails if partial data is written.
+func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
+ // If the is no second buffer, issue a regular write.
+ if len(b2) == 0 {
+ return NonBlockingWrite(fd, b1)
+ }
+
+ // We have two buffers. Build the iovec that represents them and issue
+ // a writev syscall.
+ iovec := [3]syscall.Iovec{
+ {
+ Base: &b1[0],
+ Len: uint64(len(b1)),
+ },
+ {
+ Base: &b2[0],
+ Len: uint64(len(b2)),
+ },
+ }
+ iovecLen := uintptr(2)
+
+ if len(b3) > 0 {
+ iovecLen++
+ iovec[2].Base = &b3[0]
+ iovec[2].Len = uint64(len(b3))
+ }
+
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
+ if e != 0 {
+ return TranslateErrno(e)
+ }
+
+ return nil
+}
+
+// PollEvent represents the pollfd structure passed to a poll() system call.
+type PollEvent struct {
+ FD int32
+ Events int16
+ Revents int16
+}
+
+// BlockingRead reads from a file descriptor that is set up as non-blocking. If
+// no data is available, it will block in a poll() syscall until the file
+// descirptor becomes readable.
+func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
+ if e == 0 {
+ return int(n), nil
+ }
+
+ event := PollEvent{
+ FD: int32(fd),
+ Events: 1, // POLLIN
+ }
+
+ _, e = BlockingPoll(&event, 1, -1)
+ if e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
+
+// BlockingReadv reads from a file descriptor that is set up as non-blocking and
+// stores the data in a list of iovecs buffers. If no data is available, it will
+// block in a poll() syscall until the file descriptor becomes readable.
+func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
+ if e == 0 {
+ return int(n), nil
+ }
+
+ event := PollEvent{
+ FD: int32(fd),
+ Events: 1, // POLLIN
+ }
+
+ _, e = BlockingPoll(&event, 1, -1)
+ if e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
+
+// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
+type MMsgHdr struct {
+ Msg syscall.Msghdr
+ Len uint32
+ _ [4]byte
+}
+
+// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking
+// and stores the received messages in a slice of MMsgHdr structures. If no data
+// is available, it will block in a poll() syscall until the file descriptor
+// becomes readable.
+func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall6(syscall.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
+ if e == 0 {
+ return int(n), nil
+ }
+
+ event := PollEvent{
+ FD: int32(fd),
+ Events: 1, // POLLIN
+ }
+
+ if _, e := BlockingPoll(&event, 1, -1); e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go
new file mode 100644
index 000000000..c16c19647
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/pcap.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sniffer
+
+import "time"
+
+type pcapHeader struct {
+ // MagicNumber is the file magic number.
+ MagicNumber uint32
+
+ // VersionMajor is the major version number.
+ VersionMajor uint16
+
+ // VersionMinor is the minor version number.
+ VersionMinor uint16
+
+ // Thiszone is the GMT to local correction.
+ Thiszone int32
+
+ // Sigfigs is the accuracy of timestamps.
+ Sigfigs uint32
+
+ // Snaplen is the max length of captured packets, in octets.
+ Snaplen uint32
+
+ // Network is the data link type.
+ Network uint32
+}
+
+const pcapPacketHeaderLen = 16
+
+type pcapPacketHeader struct {
+ // Seconds is the timestamp seconds.
+ Seconds uint32
+
+ // Microseconds is the timestamp microseconds.
+ Microseconds uint32
+
+ // IncludedLength is the number of octets of packet saved in file.
+ IncludedLength uint32
+
+ // OriginalLength is the actual length of packet.
+ OriginalLength uint32
+}
+
+func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader {
+ now := time.Now()
+ return pcapPacketHeader{
+ Seconds: uint32(now.Unix()),
+ Microseconds: uint32(now.Nanosecond() / 1000),
+ IncludedLength: incLen,
+ OriginalLength: orgLen,
+ }
+}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
new file mode 100644
index 000000000..fccabd554
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -0,0 +1,408 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sniffer provides the implementation of data-link layer endpoints that
+// wrap another endpoint and logs inbound and outbound packets.
+//
+// Sniffer endpoints can be used in the networking stack by calling New(eID) to
+// create a new endpoint, where eID is the ID of the endpoint being wrapped,
+// and then passing it as an argument to Stack.CreateNIC().
+package sniffer
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "os"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// LogPackets is a flag used to enable or disable packet logging via the log
+// package. Valid values are 0 or 1.
+//
+// LogPackets must be accessed atomically.
+var LogPackets uint32 = 1
+
+// LogPacketsToFile is a flag used to enable or disable logging packets to a
+// pcap file. Valid values are 0 or 1. A file must have been specified when the
+// sniffer was created for this flag to have effect.
+//
+// LogPacketsToFile must be accessed atomically.
+var LogPacketsToFile uint32 = 1
+
+type endpoint struct {
+ dispatcher stack.NetworkDispatcher
+ lower stack.LinkEndpoint
+ file *os.File
+ maxPCAPLen uint32
+}
+
+// New creates a new sniffer link-layer endpoint. It wraps around another
+// endpoint and logs packets and they traverse the endpoint.
+func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID {
+ return stack.RegisterLinkEndpoint(&endpoint{
+ lower: stack.FindLinkEndpoint(lower),
+ })
+}
+
+func zoneOffset() (int32, error) {
+ loc, err := time.LoadLocation("Local")
+ if err != nil {
+ return 0, err
+ }
+ date := time.Date(0, 0, 0, 0, 0, 0, 0, loc)
+ _, offset := date.Zone()
+ return int32(offset), nil
+}
+
+func writePCAPHeader(w io.Writer, maxLen uint32) error {
+ offset, err := zoneOffset()
+ if err != nil {
+ return err
+ }
+ return binary.Write(w, binary.BigEndian, pcapHeader{
+ // From https://wiki.wireshark.org/Development/LibpcapFileFormat
+ MagicNumber: 0xa1b2c3d4,
+
+ VersionMajor: 2,
+ VersionMinor: 4,
+ Thiszone: offset,
+ Sigfigs: 0,
+ Snaplen: maxLen,
+ Network: 101, // LINKTYPE_RAW
+ })
+}
+
+// NewWithFile creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets and they traverse the endpoint.
+//
+// Packets can be logged to file in the pcap format. A sniffer created
+// with this function will not emit packets using the standard log
+// package.
+//
+// snapLen is the maximum amount of a packet to be saved. Packets with a length
+// less than or equal too snapLen will be saved in their entirety. Longer
+// packets will be truncated to snapLen.
+func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) {
+ if err := writePCAPHeader(file, snapLen); err != nil {
+ return 0, err
+ }
+ return stack.RegisterLinkEndpoint(&endpoint{
+ lower: stack.FindLinkEndpoint(lower),
+ file: file,
+ maxPCAPLen: snapLen,
+ }), nil
+}
+
+// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
+// called by the link-layer endpoint being wrapped when a packet arrives, and
+// logs the packet before forwarding to the actual dispatcher.
+func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
+ logPacket("recv", protocol, vv.First())
+ }
+ if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+ vs := vv.Views()
+ length := vv.Size()
+ if length > int(e.maxPCAPLen) {
+ length = int(e.maxPCAPLen)
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil {
+ panic(err)
+ }
+ for _, v := range vs {
+ if length == 0 {
+ break
+ }
+ if len(v) > length {
+ v = v[:length]
+ }
+ if _, err := buf.Write([]byte(v)); err != nil {
+ panic(err)
+ }
+ length -= len(v)
+ }
+ if _, err := e.file.Write(buf.Bytes()); err != nil {
+ panic(err)
+ }
+ }
+ e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv)
+}
+
+// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
+// and registers with the lower endpoint as its dispatcher so that "e" is called
+// for inbound packets.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ e.lower.Attach(e)
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
+// lower endpoint.
+func (e *endpoint) MTU() uint32 {
+ return e.lower.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
+// request to the lower endpoint.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.lower.Capabilities()
+}
+
+// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards
+// the request to the lower endpoint.
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.lower.MaxHeaderLength()
+}
+
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.lower.LinkAddress()
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.lower.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// WritePacket implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it just logs the packet and forwards
+// the request to the lower endpoint.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
+ logPacket("send", protocol, hdr.View())
+ }
+ if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+ hdrBuf := hdr.View()
+ length := len(hdrBuf) + payload.Size()
+ if length > int(e.maxPCAPLen) {
+ length = int(e.maxPCAPLen)
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil {
+ panic(err)
+ }
+ if len(hdrBuf) > length {
+ hdrBuf = hdrBuf[:length]
+ }
+ if _, err := buf.Write(hdrBuf); err != nil {
+ panic(err)
+ }
+ length -= len(hdrBuf)
+ if length > 0 {
+ for _, v := range payload.Views() {
+ if len(v) > length {
+ v = v[:length]
+ }
+ n, err := buf.Write(v)
+ if err != nil {
+ panic(err)
+ }
+ length -= n
+ if length == 0 {
+ break
+ }
+ }
+ }
+ if _, err := e.file.Write(buf.Bytes()); err != nil {
+ panic(err)
+ }
+ }
+ return e.lower.WritePacket(r, gso, hdr, payload, protocol)
+}
+
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View) {
+ // Figure out the network layer info.
+ var transProto uint8
+ src := tcpip.Address("unknown")
+ dst := tcpip.Address("unknown")
+ id := 0
+ size := uint16(0)
+ var fragmentOffset uint16
+ var moreFragments bool
+ switch protocol {
+ case header.IPv4ProtocolNumber:
+ ipv4 := header.IPv4(b)
+ fragmentOffset = ipv4.FragmentOffset()
+ moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
+ src = ipv4.SourceAddress()
+ dst = ipv4.DestinationAddress()
+ transProto = ipv4.Protocol()
+ size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
+ b = b[ipv4.HeaderLength():]
+ id = int(ipv4.ID())
+
+ case header.IPv6ProtocolNumber:
+ ipv6 := header.IPv6(b)
+ src = ipv6.SourceAddress()
+ dst = ipv6.DestinationAddress()
+ transProto = ipv6.NextHeader()
+ size = ipv6.PayloadLength()
+ b = b[header.IPv6MinimumSize:]
+
+ case header.ARPProtocolNumber:
+ arp := header.ARP(b)
+ log.Infof(
+ "%s arp %v (%v) -> %v (%v) valid:%v",
+ prefix,
+ tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
+ tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
+ arp.IsValid(),
+ )
+ return
+ default:
+ log.Infof("%s unknown network protocol", prefix)
+ return
+ }
+
+ // Figure out the transport layer info.
+ transName := "unknown"
+ srcPort := uint16(0)
+ dstPort := uint16(0)
+ details := ""
+ switch tcpip.TransportProtocolNumber(transProto) {
+ case header.ICMPv4ProtocolNumber:
+ transName = "icmp"
+ icmp := header.ICMPv4(b)
+ icmpType := "unknown"
+ if fragmentOffset == 0 {
+ switch icmp.Type() {
+ case header.ICMPv4EchoReply:
+ icmpType = "echo reply"
+ case header.ICMPv4DstUnreachable:
+ icmpType = "destination unreachable"
+ case header.ICMPv4SrcQuench:
+ icmpType = "source quench"
+ case header.ICMPv4Redirect:
+ icmpType = "redirect"
+ case header.ICMPv4Echo:
+ icmpType = "echo"
+ case header.ICMPv4TimeExceeded:
+ icmpType = "time exceeded"
+ case header.ICMPv4ParamProblem:
+ icmpType = "param problem"
+ case header.ICMPv4Timestamp:
+ icmpType = "timestamp"
+ case header.ICMPv4TimestampReply:
+ icmpType = "timestamp reply"
+ case header.ICMPv4InfoRequest:
+ icmpType = "info request"
+ case header.ICMPv4InfoReply:
+ icmpType = "info reply"
+ }
+ }
+ log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ return
+
+ case header.ICMPv6ProtocolNumber:
+ transName = "icmp"
+ icmp := header.ICMPv6(b)
+ icmpType := "unknown"
+ switch icmp.Type() {
+ case header.ICMPv6DstUnreachable:
+ icmpType = "destination unreachable"
+ case header.ICMPv6PacketTooBig:
+ icmpType = "packet too big"
+ case header.ICMPv6TimeExceeded:
+ icmpType = "time exceeded"
+ case header.ICMPv6ParamProblem:
+ icmpType = "param problem"
+ case header.ICMPv6EchoRequest:
+ icmpType = "echo request"
+ case header.ICMPv6EchoReply:
+ icmpType = "echo reply"
+ case header.ICMPv6RouterSolicit:
+ icmpType = "router solicit"
+ case header.ICMPv6RouterAdvert:
+ icmpType = "router advert"
+ case header.ICMPv6NeighborSolicit:
+ icmpType = "neighbor solicit"
+ case header.ICMPv6NeighborAdvert:
+ icmpType = "neighbor advert"
+ case header.ICMPv6RedirectMsg:
+ icmpType = "redirect message"
+ }
+ log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ return
+
+ case header.UDPProtocolNumber:
+ transName = "udp"
+ udp := header.UDP(b)
+ if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
+ srcPort = udp.SourcePort()
+ dstPort = udp.DestinationPort()
+ }
+ size -= header.UDPMinimumSize
+
+ details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
+
+ case header.TCPProtocolNumber:
+ transName = "tcp"
+ tcp := header.TCP(b)
+ if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize {
+ offset := int(tcp.DataOffset())
+ if offset < header.TCPMinimumSize {
+ details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
+ break
+ }
+ if offset > len(tcp) && !moreFragments {
+ details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp))
+ break
+ }
+
+ srcPort = tcp.SourcePort()
+ dstPort = tcp.DestinationPort()
+ size -= uint16(offset)
+
+ // Initialize the TCP flags.
+ flags := tcp.Flags()
+ flagsStr := []byte("FSRPAU")
+ for i := range flagsStr {
+ if flags&(1<<uint(i)) == 0 {
+ flagsStr[i] = ' '
+ }
+ }
+ details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
+ if flags&header.TCPFlagSyn != 0 {
+ details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
+ } else {
+ details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions())
+ }
+ }
+
+ default:
+ log.Infof("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto)
+ return
+ }
+
+ log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+}
diff --git a/pkg/tcpip/link/sniffer/sniffer_state_autogen.go b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go
new file mode 100755
index 000000000..cfd84a739
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package sniffer
+
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
new file mode 100644
index 000000000..a3f2bce3e
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp.go
@@ -0,0 +1,203 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package arp implements the ARP network protocol. It is used to resolve
+// IPv4 addresses into link-local MAC addresses, and advertises IPv4
+// addresses of its stack with the local network.
+//
+// To use it in the networking stack, pass arp.ProtocolName as one of the
+// network protocols when calling stack.New. Then add an "arp" address to
+// every NIC on the stack that should respond to ARP requests. That is:
+//
+// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
+// // handle err
+// }
+package arp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolName is the string representation of the ARP protocol name.
+ ProtocolName = "arp"
+
+ // ProtocolNumber is the ARP protocol number.
+ ProtocolNumber = header.ARPProtocolNumber
+
+ // ProtocolAddress is the address expected by the ARP endpoint.
+ ProtocolAddress = tcpip.Address("arp")
+)
+
+// endpoint implements stack.NetworkEndpoint.
+type endpoint struct {
+ nicid tcpip.NICID
+ addr tcpip.Address
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache
+}
+
+// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
+func (e *endpoint) MTU() uint32 {
+ lmtu := e.linkEP.MTU()
+ return lmtu - uint32(e.MaxHeaderLength())
+}
+
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicid
+}
+
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &stack.NetworkEndpointID{ProtocolAddress}
+}
+
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.ARPSize
+}
+
+func (e *endpoint) Close() {}
+
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
+ v := vv.First()
+ h := header.ARP(v)
+ if !h.IsValid() {
+ return
+ }
+
+ switch h.Op() {
+ case header.ARPRequest:
+ localAddr := tcpip.Address(h.ProtocolAddressTarget())
+ if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+ hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
+ pkt := header.ARP(hdr.Prepend(header.ARPSize))
+ pkt.SetIPv4OverEthernet()
+ pkt.SetOp(header.ARPReply)
+ copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
+ copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
+ e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ fallthrough // also fill the cache from requests
+ case header.ARPReply:
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
+ }
+}
+
+// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
+type protocol struct {
+}
+
+func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
+func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
+
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.ARP(v)
+ return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
+}
+
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ if addr != ProtocolAddress {
+ return nil, tcpip.ErrBadLocalAddress
+ }
+ return &endpoint{
+ nicid: nicid,
+ addr: addr,
+ linkEP: sender,
+ linkAddrCache: linkAddrCache,
+ }, nil
+}
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv4ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ r := &stack.Route{
+ RemoteLinkAddress: broadcastMAC,
+ }
+
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
+ h := header.ARP(hdr.Prepend(header.ARPSize))
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), linkEP.LinkAddress())
+ copy(h.ProtocolAddressSender(), localAddr)
+ copy(h.ProtocolAddressTarget(), addr)
+
+ return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == header.IPv4Broadcast {
+ return broadcastMAC, true
+ }
+ if header.IsV4MulticastAddress(addr) {
+ // RFC 1112 Host Extensions for IP Multicasting
+ //
+ // 6.4. Extensions to an Ethernet Local Network Module:
+ //
+ // An IP host group address is mapped to an Ethernet multicast
+ // address by placing the low-order 23-bits of the IP address
+ // into the low-order 23 bits of the Ethernet multicast address
+ // 01-00-5E-00-00-00 (hex).
+ return tcpip.LinkAddress([]byte{
+ 0x01,
+ 0x00,
+ 0x5e,
+ addr[header.IPv4AddressSize-3] & 0x7f,
+ addr[header.IPv4AddressSize-2],
+ addr[header.IPv4AddressSize-1],
+ }), true
+ }
+ return "", false
+}
+
+// SetOption implements NetworkProtocol.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+func init() {
+ stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go
new file mode 100755
index 000000000..14a21baff
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package arp
+
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
new file mode 100644
index 000000000..9ad3e5a8a
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/frag_heap.go
@@ -0,0 +1,77 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+type fragment struct {
+ offset uint16
+ vv buffer.VectorisedView
+}
+
+type fragHeap []fragment
+
+func (h *fragHeap) Len() int {
+ return len(*h)
+}
+
+func (h *fragHeap) Less(i, j int) bool {
+ return (*h)[i].offset < (*h)[j].offset
+}
+
+func (h *fragHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
+}
+
+func (h *fragHeap) Push(x interface{}) {
+ *h = append(*h, x.(fragment))
+}
+
+func (h *fragHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[:n-1]
+ return x
+}
+
+// reassamble empties the heap and returns a VectorisedView
+// containing a reassambled version of the fragments inside the heap.
+func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
+ curr := heap.Pop(h).(fragment)
+ views := curr.vv.Views()
+ size := curr.vv.Size()
+
+ if curr.offset != 0 {
+ return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
+ }
+
+ for h.Len() > 0 {
+ curr := heap.Pop(h).(fragment)
+ if int(curr.offset) < size {
+ curr.vv.TrimFront(size - int(curr.offset))
+ } else if int(curr.offset) > size {
+ return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
+ }
+ size += curr.vv.Size()
+ views = append(views, curr.vv.Views()...)
+ }
+ return buffer.NewVectorisedView(size, views), nil
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
new file mode 100644
index 000000000..e90edb375
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fragmentation contains the implementation of IP fragmentation.
+// It is based on RFC 791 and RFC 815.
+package fragmentation
+
+import (
+ "log"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
+const DefaultReassembleTimeout = 30 * time.Second
+
+// HighFragThreshold is the threshold at which we start trimming old
+// fragmented packets. Linux uses a default value of 4 MB. See
+// net.ipv4.ipfrag_high_thresh for more information.
+const HighFragThreshold = 4 << 20 // 4MB
+
+// LowFragThreshold is the threshold we reach to when we start dropping
+// older fragmented packets. It's important that we keep enough room for newer
+// packets to be re-assembled. Hence, this needs to be lower than
+// HighFragThreshold enough. Linux uses a default value of 3 MB. See
+// net.ipv4.ipfrag_low_thresh for more information.
+const LowFragThreshold = 3 << 20 // 3MB
+
+// Fragmentation is the main structure that other modules
+// of the stack should use to implement IP Fragmentation.
+type Fragmentation struct {
+ mu sync.Mutex
+ highLimit int
+ lowLimit int
+ reassemblers map[uint32]*reassembler
+ rList reassemblerList
+ size int
+ timeout time.Duration
+}
+
+// NewFragmentation creates a new Fragmentation.
+//
+// highMemoryLimit specifies the limit on the memory consumed
+// by the fragments stored by Fragmentation (overhead of internal data-structures
+// is not accounted). Fragments are dropped when the limit is reached.
+//
+// lowMemoryLimit specifies the limit on which we will reach by dropping
+// fragments after reaching highMemoryLimit.
+//
+// reassemblingTimeout specifes the maximum time allowed to reassemble a packet.
+// Fragments are lazily evicted only when a new a packet with an
+// already existing fragmentation-id arrives after the timeout.
+func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+ if lowMemoryLimit >= highMemoryLimit {
+ lowMemoryLimit = highMemoryLimit
+ }
+
+ if lowMemoryLimit < 0 {
+ lowMemoryLimit = 0
+ }
+
+ return &Fragmentation{
+ reassemblers: make(map[uint32]*reassembler),
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
+ timeout: reassemblingTimeout,
+ }
+}
+
+// Process processes an incoming fragment beloning to an ID
+// and returns a complete packet when all the packets belonging to that ID have been received.
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
+ f.mu.Lock()
+ r, ok := f.reassemblers[id]
+ if ok && r.tooOld(f.timeout) {
+ // This is very likely to be an id-collision or someone performing a slow-rate attack.
+ f.release(r)
+ ok = false
+ }
+ if !ok {
+ r = newReassembler(id)
+ f.reassemblers[id] = r
+ f.rList.PushFront(r)
+ }
+ f.mu.Unlock()
+
+ res, done, consumed := r.process(first, last, more, vv)
+
+ f.mu.Lock()
+ f.size += consumed
+ if done {
+ f.release(r)
+ }
+ // Evict reassemblers if we are consuming more memory than highLimit until
+ // we reach lowLimit.
+ if f.size > f.highLimit {
+ tail := f.rList.Back()
+ for f.size > f.lowLimit && tail != nil {
+ f.release(tail)
+ tail = tail.Prev()
+ }
+ }
+ f.mu.Unlock()
+ return res, done
+}
+
+func (f *Fragmentation) release(r *reassembler) {
+ // Before releasing a fragment we need to check if r is already marked as done.
+ // Otherwise, we would delete it twice.
+ if r.checkDoneOrMark() {
+ return
+ }
+
+ delete(f.reassemblers, r.id)
+ f.rList.Remove(r)
+ f.size -= r.size
+ if f.size < 0 {
+ log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
+ f.size = 0
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go
new file mode 100755
index 000000000..c012e8012
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go
@@ -0,0 +1,38 @@
+// automatically generated by stateify.
+
+package fragmentation
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *reassemblerList) beforeSave() {}
+func (x *reassemblerList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *reassemblerList) afterLoad() {}
+func (x *reassemblerList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *reassemblerEntry) beforeSave() {}
+func (x *reassemblerEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *reassemblerEntry) afterLoad() {}
+func (x *reassemblerEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("fragmentation.reassemblerList", (*reassemblerList)(nil), state.Fns{Save: (*reassemblerList).save, Load: (*reassemblerList).load})
+ state.Register("fragmentation.reassemblerEntry", (*reassemblerEntry)(nil), state.Fns{Save: (*reassemblerEntry).save, Load: (*reassemblerEntry).load})
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
new file mode 100644
index 000000000..04f9ab964
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+ "math"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+type hole struct {
+ first uint16
+ last uint16
+ deleted bool
+}
+
+type reassembler struct {
+ reassemblerEntry
+ id uint32
+ size int
+ mu sync.Mutex
+ holes []hole
+ deleted int
+ heap fragHeap
+ done bool
+ creationTime time.Time
+}
+
+func newReassembler(id uint32) *reassembler {
+ r := &reassembler{
+ id: id,
+ holes: make([]hole, 0, 16),
+ deleted: 0,
+ heap: make(fragHeap, 0, 8),
+ creationTime: time.Now(),
+ }
+ r.holes = append(r.holes, hole{
+ first: 0,
+ last: math.MaxUint16,
+ deleted: false})
+ return r
+}
+
+// updateHoles updates the list of holes for an incoming fragment and
+// returns true iff the fragment filled at least part of an existing hole.
+func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
+ used := false
+ for i := range r.holes {
+ if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
+ continue
+ }
+ used = true
+ r.deleted++
+ r.holes[i].deleted = true
+ if first > r.holes[i].first {
+ r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
+ }
+ if last < r.holes[i].last && more {
+ r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
+ }
+ }
+ return used
+}
+
+func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ consumed := 0
+ if r.done {
+ // A concurrent goroutine might have already reassembled
+ // the packet and emptied the heap while this goroutine
+ // was waiting on the mutex. We don't have to do anything in this case.
+ return buffer.VectorisedView{}, false, consumed
+ }
+ if r.updateHoles(first, last, more) {
+ // We store the incoming packet only if it filled some holes.
+ heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
+ consumed = vv.Size()
+ r.size += consumed
+ }
+ // Check if all the holes have been deleted and we are ready to reassamble.
+ if r.deleted < len(r.holes) {
+ return buffer.VectorisedView{}, false, consumed
+ }
+ res, err := r.heap.reassemble()
+ if err != nil {
+ panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
+ }
+ return res, true, consumed
+}
+
+func (r *reassembler) tooOld(timeout time.Duration) bool {
+ return time.Now().Sub(r.creationTime) > timeout
+}
+
+func (r *reassembler) checkDoneOrMark() bool {
+ r.mu.Lock()
+ prev := r.done
+ r.done = true
+ r.mu.Unlock()
+ return prev
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_list.go b/pkg/tcpip/network/fragmentation/reassembler_list.go
new file mode 100755
index 000000000..3189cae29
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler_list.go
@@ -0,0 +1,173 @@
+package fragmentation
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type reassemblerElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type reassemblerList struct {
+ head *reassembler
+ tail *reassembler
+}
+
+// Reset resets list l to the empty state.
+func (l *reassemblerList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *reassemblerList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *reassemblerList) Front() *reassembler {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *reassemblerList) Back() *reassembler {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *reassemblerList) PushFront(e *reassembler) {
+ reassemblerElementMapper{}.linkerFor(e).SetNext(l.head)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *reassemblerList) PushBack(e *reassembler) {
+ reassemblerElementMapper{}.linkerFor(e).SetNext(nil)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *reassemblerList) PushBackList(m *reassemblerList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *reassemblerList) InsertAfter(b, e *reassembler) {
+ a := reassemblerElementMapper{}.linkerFor(b).Next()
+ reassemblerElementMapper{}.linkerFor(e).SetNext(a)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(b)
+ reassemblerElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *reassemblerList) InsertBefore(a, e *reassembler) {
+ b := reassemblerElementMapper{}.linkerFor(a).Prev()
+ reassemblerElementMapper{}.linkerFor(e).SetNext(a)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(b)
+ reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ reassemblerElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *reassemblerList) Remove(e *reassembler) {
+ prev := reassemblerElementMapper{}.linkerFor(e).Prev()
+ next := reassemblerElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ reassemblerElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ reassemblerElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type reassemblerEntry struct {
+ next *reassembler
+ prev *reassembler
+}
+
+// Next returns the entry that follows e in the list.
+func (e *reassemblerEntry) Next() *reassembler {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *reassemblerEntry) Prev() *reassembler {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *reassemblerEntry) SetNext(elem *reassembler) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *reassemblerEntry) SetPrev(elem *reassembler) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go
new file mode 100644
index 000000000..0c91905dc
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hash contains utility functions for hashing.
+package hash
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+var hashIV = RandN32(1)[0]
+
+// RandN32 generates a slice of n cryptographic random 32-bit numbers.
+func RandN32(n int) []uint32 {
+ b := make([]byte, 4*n)
+ if _, err := rand.Read(b); err != nil {
+ panic("unable to get random numbers: " + err.Error())
+ }
+ r := make([]uint32, n)
+ for i := range r {
+ r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)])
+ }
+ return r
+}
+
+// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted
+// from linux.
+func Hash3Words(a, b, c, initval uint32) uint32 {
+ const iv = 0xdeadbeef + (3 << 2)
+ initval += iv
+
+ a += initval
+ b += initval
+ c += initval
+
+ c ^= b
+ c -= rol32(b, 14)
+ a ^= c
+ a -= rol32(c, 11)
+ b ^= a
+ b -= rol32(a, 25)
+ c ^= b
+ c -= rol32(b, 16)
+ a ^= c
+ a -= rol32(c, 4)
+ b ^= a
+ b -= rol32(a, 14)
+ c ^= b
+ c -= rol32(b, 24)
+
+ return c
+}
+
+// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791.
+func IPv4FragmentHash(h header.IPv4) uint32 {
+ x := uint32(h.ID())<<16 | uint32(h.Protocol())
+ t := h.SourceAddress()
+ y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = h.DestinationAddress()
+ z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return Hash3Words(x, y, z, hashIV)
+}
+
+// IPv6FragmentHash computes the hash of the ipv6 fragment.
+// Unlike IPv4, the protocol is not used to compute the hash.
+// RFC 2640 (sec 4.5) is not very sharp on this aspect.
+// As a reference, also Linux ignores the protocol to compute
+// the hash (inet6_hash_frag).
+func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 {
+ t := h.SourceAddress()
+ y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = h.DestinationAddress()
+ z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return Hash3Words(f.ID(), y, z, hashIV)
+}
+
+func rol32(v, shift uint32) uint32 {
+ return (v << shift) | (v >> ((-shift) & 31))
+}
diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go
new file mode 100755
index 000000000..a3bcd4b69
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package hash
+
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
new file mode 100644
index 000000000..770f56c3d
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -0,0 +1,160 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ h := header.IPv4(vv.First())
+
+ // We don't use IsValid() here because ICMP only requires that the IP
+ // header plus 8 bytes of the transport header be included. So it's
+ // likely that it is truncated, which would cause IsValid to return
+ // false.
+ //
+ // Drop packet if it doesn't have the basic IPv4 header or if the
+ // original source address doesn't match the endpoint's address.
+ if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ hlen := int(h.HeaderLength())
+ if vv.Size() < hlen || h.FragmentOffset() != 0 {
+ // We won't be able to handle this if it doesn't contain the
+ // full IPv4 header, or if it's a fragment not at offset 0
+ // (because it won't have the transport header).
+ return
+ }
+
+ // Skip the ip header, then deliver control message.
+ vv.TrimFront(hlen)
+ p := h.TransportProtocol()
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+ stats := r.Stats()
+ received := stats.ICMP.V4PacketsReceived
+ v := vv.First()
+ if len(v) < header.ICMPv4MinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.ICMPv4(v)
+
+ // TODO(b/112892170): Meaningfully handle all ICMP types.
+ switch h.Type() {
+ case header.ICMPv4Echo:
+ received.Echo.Increment()
+ if len(v) < header.ICMPv4EchoMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ // Only send a reply if the checksum is valid.
+ wantChecksum := h.Checksum()
+ // Reset the checksum field to 0 to can calculate the proper
+ // checksum. We'll have to reset this before we hand the packet
+ // off.
+ h.SetChecksum(0)
+ gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */)
+ if gotChecksum != wantChecksum {
+ // It's possible that a raw socket expects to receive this.
+ h.SetChecksum(wantChecksum)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ received.Invalid.Increment()
+ return
+ }
+
+ // It's possible that a raw socket expects to receive this.
+ h.SetChecksum(wantChecksum)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+
+ vv := vv.Clone(nil)
+ vv.TrimFront(header.ICMPv4EchoMinimumSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4EchoMinimumSize)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
+ copy(pkt, h)
+ pkt.SetType(header.ICMPv4EchoReply)
+ pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
+ sent := stats.ICMP.V4PacketsSent
+ if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.EchoReply.Increment()
+
+ case header.ICMPv4EchoReply:
+ received.EchoReply.Increment()
+ if len(v) < header.ICMPv4EchoMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+
+ case header.ICMPv4DstUnreachable:
+ received.DstUnreachable.Increment()
+ if len(v) < header.ICMPv4DstUnreachableMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize)
+ switch h.Code() {
+ case header.ICMPv4PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, vv)
+
+ case header.ICMPv4FragmentationNeeded:
+ mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:]))
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ }
+
+ case header.ICMPv4SrcQuench:
+ received.SrcQuench.Increment()
+
+ case header.ICMPv4Redirect:
+ received.Redirect.Increment()
+
+ case header.ICMPv4TimeExceeded:
+ received.TimeExceeded.Increment()
+
+ case header.ICMPv4ParamProblem:
+ received.ParamProblem.Increment()
+
+ case header.ICMPv4Timestamp:
+ received.Timestamp.Increment()
+
+ case header.ICMPv4TimestampReply:
+ received.TimestampReply.Increment()
+
+ case header.ICMPv4InfoRequest:
+ received.InfoRequest.Increment()
+
+ case header.ICMPv4InfoReply:
+ received.InfoReply.Increment()
+
+ default:
+ received.Invalid.Increment()
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
new file mode 100644
index 000000000..da07a39e5
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -0,0 +1,344 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ipv4 contains the implementation of the ipv4 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
+// network protocols when calling stack.New(). Then endpoints can be created
+// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv4
+
+import (
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/hash"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolName is the string representation of the ipv4 protocol name.
+ ProtocolName = "ipv4"
+
+ // ProtocolNumber is the ipv4 protocol number.
+ ProtocolNumber = header.IPv4ProtocolNumber
+
+ // MaxTotalSize is maximum size that can be encoded in the 16-bit
+ // TotalLength field of the ipv4 header.
+ MaxTotalSize = 0xffff
+
+ // buckets is the number of identifier buckets.
+ buckets = 2048
+)
+
+type endpoint struct {
+ nicid tcpip.NICID
+ id stack.NetworkEndpointID
+ linkEP stack.LinkEndpoint
+ dispatcher stack.TransportDispatcher
+ fragmentation *fragmentation.Fragmentation
+}
+
+// NewEndpoint creates a new ipv4 endpoint.
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ e := &endpoint{
+ nicid: nicid,
+ id: stack.NetworkEndpointID{LocalAddress: addr},
+ linkEP: linkEP,
+ dispatcher: dispatcher,
+ fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ }
+
+ return e, nil
+}
+
+// DefaultTTL is the default time-to-live value for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 255
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// NICID returns the ID of the NIC this endpoint belongs to.
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicid
+}
+
+// ID returns the ipv4 endpoint ID.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &e.id
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
+// write. It assumes that the IP header is entirely in hdr but does not assume
+// that only the IP header is in hdr. It assumes that the input packet's stated
+// length matches the length of the hdr+payload. mtu includes the IP header and
+// options. This does not support the DontFragment IP flag.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error {
+ // This packet is too big, it needs to be fragmented.
+ ip := header.IPv4(hdr.View())
+ flags := ip.Flags()
+
+ // Update mtu to take into account the header, which will exist in all
+ // fragments anyway.
+ innerMTU := mtu - int(ip.HeaderLength())
+
+ // Round the MTU down to align to 8 bytes. Then calculate the number of
+ // fragments. Calculate fragment sizes as in RFC791.
+ innerMTU &^= 7
+ n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
+
+ outerMTU := innerMTU + int(ip.HeaderLength())
+ offset := ip.FragmentOffset()
+ originalAvailableLength := hdr.AvailableLength()
+ for i := 0; i < n; i++ {
+ // Where possible, the first fragment that is sent has the same
+ // hdr.UsedLength() as the input packet. The link-layer endpoint may depends
+ // on this for looking at, eg, L4 headers.
+ h := ip
+ if i > 0 {
+ hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
+ h = header.IPv4(hdr.Prepend(int(ip.HeaderLength())))
+ copy(h, ip[:ip.HeaderLength()])
+ }
+ if i != n-1 {
+ h.SetTotalLength(uint16(outerMTU))
+ h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
+ } else {
+ h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size()))
+ h.SetFlagsFragmentOffset(flags, offset)
+ }
+ h.SetChecksum(0)
+ h.SetChecksum(^h.CalculateChecksum())
+ offset += uint16(innerMTU)
+ if i > 0 {
+ newPayload := payload.Clone([]buffer.View{})
+ newPayload.CapLength(innerMTU)
+ if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ payload.TrimFront(newPayload.Size())
+ continue
+ }
+ // Special handling for the first fragment because it comes from the hdr.
+ if outerMTU >= hdr.UsedLength() {
+ // This fragment can fit all of hdr and possibly some of payload, too.
+ newPayload := payload.Clone([]buffer.View{})
+ newPayloadLength := outerMTU - hdr.UsedLength()
+ newPayload.CapLength(newPayloadLength)
+ if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ payload.TrimFront(newPayloadLength)
+ } else {
+ // The fragment is too small to fit all of hdr.
+ startOfHdr := hdr
+ startOfHdr.TrimBack(hdr.UsedLength() - outerMTU)
+ emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
+ if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ // Add the unused bytes of hdr into the payload that remains to be sent.
+ restOfHdr := hdr.View()[outerMTU:]
+ tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
+ tmp.Append(payload)
+ payload = tmp
+ }
+ }
+ return nil
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ length := uint16(hdr.UsedLength() + payload.Size())
+ id := uint32(0)
+ if length > header.IPv4MaximumHeaderSize+8 {
+ // Packets of 68 bytes or less are required by RFC 791 to not be
+ // fragmented, so we only assign ids to larger packets.
+ id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: length,
+ ID: uint16(id),
+ TTL: ttl,
+ Protocol: uint8(protocol),
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ if loop&stack.PacketLoop != 0 {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ e.HandlePacket(r, vv)
+ }
+ if loop&stack.PacketOut == 0 {
+ return nil
+ }
+ if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU()))
+ }
+ if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ return nil
+}
+
+// HandlePacket is called by the link layer when new ipv4 packets arrive for
+// this endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
+ headerView := vv.First()
+ h := header.IPv4(headerView)
+ if !h.IsValid(vv.Size()) {
+ return
+ }
+
+ hlen := int(h.HeaderLength())
+ tlen := int(h.TotalLength())
+ vv.TrimFront(hlen)
+ vv.CapLength(tlen - hlen)
+
+ more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
+ if more || h.FragmentOffset() != 0 {
+ // The packet is a fragment, let's try to reassemble it.
+ last := h.FragmentOffset() + uint16(vv.Size()) - 1
+ var ready bool
+ vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ if !ready {
+ return
+ }
+ }
+ p := h.TransportProtocol()
+ if p == header.ICMPv4ProtocolNumber {
+ headerView.CapLength(hlen)
+ e.handleICMP(r, headerView, vv)
+ return
+ }
+ r.Stats().IP.PacketsDelivered.Increment()
+ e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
+}
+
+// Close cleans up resources associated with the endpoint.
+func (e *endpoint) Close() {}
+
+type protocol struct{}
+
+// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
+// only for tests that short-circuit the stack. Regular use of the protocol is
+// done via the stack, which gets a protocol descriptor from the init() function
+// below.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
+}
+
+// Number returns the ipv4 protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
+// MinimumPacketSize returns the minimum valid ipv4 packet size.
+func (p *protocol) MinimumPacketSize() int {
+ return header.IPv4MinimumSize
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv4(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload mtu.
+func calculateMTU(mtu uint32) uint32 {
+ if mtu > MaxTotalSize {
+ mtu = MaxTotalSize
+ }
+ return mtu - header.IPv4MinimumSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source &
+// destination address, the transport protocol number, and a random initial
+// value (generated once on initialization) to generate the hash.
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
+ t := r.LocalAddress
+ a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ t = r.RemoteAddress
+ b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ return hash.Hash3Words(a, b, uint32(protocol), hashIV)
+}
+
+var (
+ ids []uint32
+ hashIV uint32
+)
+
+func init() {
+ ids = make([]uint32, buckets)
+
+ // Randomly initialize hashIV and the ids.
+ r := hash.RandN32(1 + buckets)
+ for i := range ids {
+ ids[i] = r[i]
+ }
+ hashIV = r[buckets]
+
+ stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go
new file mode 100755
index 000000000..6b2cc0142
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package ipv4
+
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
new file mode 100644
index 000000000..9c011e107
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -0,0 +1,297 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ h := header.IPv6(vv.First())
+
+ // We don't use IsValid() here because ICMP only requires that up to
+ // 1280 bytes of the original packet be included. So it's likely that it
+ // is truncated, which would cause IsValid to return false.
+ //
+ // Drop packet if it doesn't have the basic IPv6 header or if the
+ // original source address doesn't match the endpoint's address.
+ if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ // Skip the IP header, then handle the fragmentation header if there
+ // is one.
+ vv.TrimFront(header.IPv6MinimumSize)
+ p := h.TransportProtocol()
+ if p == header.IPv6FragmentHeader {
+ f := header.IPv6Fragment(vv.First())
+ if !f.IsValid() || f.FragmentOffset() != 0 {
+ // We can't handle fragments that aren't at offset 0
+ // because they don't have the transport headers.
+ return
+ }
+
+ // Skip fragmentation header and find out the actual protocol
+ // number.
+ vv.TrimFront(header.IPv6FragmentHeaderSize)
+ p = f.TransportProtocol()
+ }
+
+ // Deliver the control packet to the transport endpoint.
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+ stats := r.Stats().ICMP
+ sent := stats.V6PacketsSent
+ received := stats.V6PacketsReceived
+ v := vv.First()
+ if len(v) < header.ICMPv6MinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.ICMPv6(v)
+
+ // TODO(b/112892170): Meaningfully handle all ICMP types.
+ switch h.Type() {
+ case header.ICMPv6PacketTooBig:
+ received.PacketTooBig.Increment()
+ if len(v) < header.ICMPv6PacketTooBigMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
+ mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:])
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+
+ case header.ICMPv6DstUnreachable:
+ received.DstUnreachable.Increment()
+ if len(v) < header.ICMPv6DstUnreachableMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch h.Code() {
+ case header.ICMPv6PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, vv)
+ }
+
+ case header.ICMPv6NeighborSolicit:
+ received.NeighborSolicit.Increment()
+
+ e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+
+ if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ targetAddr := tcpip.Address(v[8:][:16])
+ if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
+ // We don't have a useful answer; the best we can do is ignore the request.
+ return
+ }
+
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag
+ copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr)
+ pkt[icmpV6OptOffset] = ndpOptDstLinkAddr
+ pkt[icmpV6LengthOffset] = 1
+ copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:])
+
+ // ICMPv6 Neighbor Solicit messages are always sent to
+ // specially crafted IPv6 multicast addresses. As a result, the
+ // route we end up with here has as its LocalAddress such a
+ // multicast address. It would be nonsense to claim that our
+ // source address is a multicast address, so we manually set
+ // the source address to the target address requested in the
+ // solicit message. Since that requires mutating the route, we
+ // must first clone it.
+ r := r.Clone()
+ defer r.Release()
+ r.LocalAddress = targetAddr
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.NeighborAdvert.Increment()
+
+ case header.ICMPv6NeighborAdvert:
+ received.NeighborAdvert.Increment()
+ if len(v) < header.ICMPv6NeighborAdvertSize {
+ received.Invalid.Increment()
+ return
+ }
+ targetAddr := tcpip.Address(v[8:][:16])
+ e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
+ if targetAddr != r.RemoteAddress {
+ e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+ }
+
+ case header.ICMPv6EchoRequest:
+ received.EchoRequest.Increment()
+ if len(v) < header.ICMPv6EchoMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ vv.TrimFront(header.ICMPv6EchoMinimumSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(pkt, h)
+ pkt.SetType(header.ICMPv6EchoReply)
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
+ if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
+ sent.Dropped.Increment()
+ return
+ }
+ sent.EchoReply.Increment()
+
+ case header.ICMPv6EchoReply:
+ received.EchoReply.Increment()
+ if len(v) < header.ICMPv6EchoMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv)
+
+ case header.ICMPv6TimeExceeded:
+ received.TimeExceeded.Increment()
+
+ case header.ICMPv6ParamProblem:
+ received.ParamProblem.Increment()
+
+ case header.ICMPv6RouterSolicit:
+ received.RouterSolicit.Increment()
+
+ case header.ICMPv6RouterAdvert:
+ received.RouterAdvert.Increment()
+
+ case header.ICMPv6RedirectMsg:
+ received.RedirectMsg.Increment()
+
+ default:
+ received.Invalid.Increment()
+ }
+}
+
+const (
+ ndpSolicitedFlag = 1 << 6
+ ndpOverrideFlag = 1 << 5
+
+ ndpOptSrcLinkAddr = 1
+ ndpOptDstLinkAddr = 2
+
+ icmpV6FlagOffset = 4
+ icmpV6OptOffset = 24
+ icmpV6LengthOffset = 25
+)
+
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+var _ stack.LinkAddressResolver = (*protocol)(nil)
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ snaddr := header.SolicitedNodeAddr(addr)
+ r := &stack.Route{
+ LocalAddress: localAddr,
+ RemoteAddress: snaddr,
+ RemoteLinkAddress: broadcastMAC,
+ }
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ copy(pkt[icmpV6OptOffset-len(addr):], addr)
+ pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
+ pkt[icmpV6LengthOffset] = 1
+ copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ length := uint16(hdr.UsedLength())
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: defaultIPv6HopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+
+ // TODO(stijlist): count this in ICMP stats.
+ return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if header.IsV6MulticastAddress(addr) {
+ // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
+ //
+ // 7. Address Mapping -- Multicast
+ //
+ // An IPv6 packet with a multicast destination address DST,
+ // consisting of the sixteen octets DST[1] through DST[16], is
+ // transmitted to the Ethernet multicast address whose first
+ // two octets are the value 3333 hexadecimal and whose last
+ // four octets are the last four octets of DST.
+ return tcpip.LinkAddress([]byte{
+ 0x33,
+ 0x33,
+ addr[header.IPv6AddressSize-4],
+ addr[header.IPv6AddressSize-3],
+ addr[header.IPv6AddressSize-2],
+ addr[header.IPv6AddressSize-1],
+ }), true
+ }
+ return "", false
+}
+
+func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
+ // Calculate the IPv6 pseudo-header upper-layer checksum.
+ xsum := header.Checksum([]byte(src), 0)
+ xsum = header.Checksum([]byte(dst), xsum)
+ var upperLayerLength [4]byte
+ binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
+ xsum = header.Checksum(upperLayerLength[:], xsum)
+ xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum)
+ for _, v := range vv.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+
+ // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
+ h2, h3 := h[2], h[3]
+ h[2], h[3] = 0, 0
+ xsum = ^header.Checksum(h, xsum)
+ h[2], h[3] = h2, h3
+
+ return xsum
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
new file mode 100644
index 000000000..4b8cd496b
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -0,0 +1,207 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ipv6 contains the implementation of the ipv6 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
+// network protocols when calling stack.New(). Then endpoints can be created
+// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv6
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolName is the string representation of the ipv6 protocol name.
+ ProtocolName = "ipv6"
+
+ // ProtocolNumber is the ipv6 protocol number.
+ ProtocolNumber = header.IPv6ProtocolNumber
+
+ // maxTotalSize is maximum size that can be encoded in the 16-bit
+ // PayloadLength field of the ipv6 header.
+ maxPayloadSize = 0xffff
+
+ // defaultIPv6HopLimit is the default hop limit for IPv6 Packets
+ // egressed by Netstack.
+ defaultIPv6HopLimit = 255
+)
+
+type endpoint struct {
+ nicid tcpip.NICID
+ id stack.NetworkEndpointID
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache
+ dispatcher stack.TransportDispatcher
+}
+
+// DefaultTTL is the default hop limit for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return 255
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
+}
+
+// NICID returns the ID of the NIC this endpoint belongs to.
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicid
+}
+
+// ID returns the ipv6 endpoint ID.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &e.id
+}
+
+// Capabilities implements stack.NetworkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
+ length := uint16(hdr.UsedLength() + payload.Size())
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(protocol),
+ HopLimit: ttl,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+
+ if loop&stack.PacketLoop != 0 {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ e.HandlePacket(r, vv)
+ }
+ if loop&stack.PacketOut == 0 {
+ return nil
+ }
+
+ r.Stats().IP.PacketsSent.Increment()
+ return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
+}
+
+// HandlePacket is called by the link layer when new ipv6 packets arrive for
+// this endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
+ headerView := vv.First()
+ h := header.IPv6(headerView)
+ if !h.IsValid(vv.Size()) {
+ return
+ }
+
+ vv.TrimFront(header.IPv6MinimumSize)
+ vv.CapLength(int(h.PayloadLength()))
+
+ p := h.TransportProtocol()
+ if p == header.ICMPv6ProtocolNumber {
+ e.handleICMP(r, headerView, vv)
+ return
+ }
+
+ r.Stats().IP.PacketsDelivered.Increment()
+ e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
+}
+
+// Close cleans up resources associated with the endpoint.
+func (*endpoint) Close() {}
+
+type protocol struct{}
+
+// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
+// only for tests that short-circuit the stack. Regular use of the protocol is
+// done via the stack, which gets a protocol descriptor from the init() function
+// below.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
+}
+
+// Number returns the ipv6 protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
+// MinimumPacketSize returns the minimum valid ipv6 packet size.
+func (p *protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint creates a new ipv6 endpoint.
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ return &endpoint{
+ nicid: nicid,
+ id: stack.NetworkEndpointID{LocalAddress: addr},
+ linkEP: linkEP,
+ linkAddrCache: linkAddrCache,
+ dispatcher: dispatcher,
+ }, nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload mtu.
+func calculateMTU(mtu uint32) uint32 {
+ mtu -= header.IPv6MinimumSize
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
+func init() {
+ stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go
new file mode 100755
index 000000000..53319e0c4
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package ipv6
+
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
new file mode 100644
index 000000000..a1712b590
--- /dev/null
+++ b/pkg/tcpip/ports/ports.go
@@ -0,0 +1,209 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ports provides PortManager that manages allocating, reserving and releasing ports.
+package ports
+
+import (
+ "math"
+ "math/rand"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ // FirstEphemeral is the first ephemeral port.
+ FirstEphemeral = 16000
+
+ anyIPAddress tcpip.Address = ""
+)
+
+type portDescriptor struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+ port uint16
+}
+
+// PortManager manages allocating, reserving and releasing ports.
+type PortManager struct {
+ mu sync.RWMutex
+ allocatedPorts map[portDescriptor]bindAddresses
+}
+
+type portNode struct {
+ reuse bool
+ refs int
+}
+
+// bindAddresses is a set of IP addresses.
+type bindAddresses map[tcpip.Address]portNode
+
+// isAvailable checks whether an IP address is available to bind to.
+func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool {
+ if addr == anyIPAddress {
+ if len(b) == 0 {
+ return true
+ }
+ if !reuse {
+ return false
+ }
+ for _, n := range b {
+ if !n.reuse {
+ return false
+ }
+ }
+ return true
+ }
+
+ // If all addresses for this portDescriptor are already bound, no
+ // address is available.
+ if n, ok := b[anyIPAddress]; ok {
+ if !reuse {
+ return false
+ }
+ if !n.reuse {
+ return false
+ }
+ }
+
+ if n, ok := b[addr]; ok {
+ if !reuse {
+ return false
+ }
+ return n.reuse
+ }
+ return true
+}
+
+// NewPortManager creates new PortManager.
+func NewPortManager() *PortManager {
+ return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)}
+}
+
+// PickEphemeralPort randomly chooses a starting point and iterates over all
+// possible ephemeral ports, allowing the caller to decide whether a given port
+// is suitable for its needs, and stopping when a port is found or an error
+// occurs.
+func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ count := uint16(math.MaxUint16 - FirstEphemeral + 1)
+ offset := uint16(rand.Int31n(int32(count)))
+
+ for i := uint16(0); i < count; i++ {
+ port = FirstEphemeral + (offset+i)%count
+ ok, err := testPort(port)
+ if err != nil {
+ return 0, err
+ }
+
+ if ok {
+ return port, nil
+ }
+ }
+
+ return 0, tcpip.ErrNoPortAvailable
+}
+
+// IsPortAvailable tests if the given port is available on all given protocols.
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.isPortAvailableLocked(networks, transport, addr, port, reuse)
+}
+
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ if addrs, ok := s.allocatedPorts[desc]; ok {
+ if !addrs.isAvailable(addr, reuse) {
+ return false
+ }
+ }
+ }
+ return true
+}
+
+// ReservePort marks a port/IP combination as reserved so that it cannot be
+// reserved by another endpoint. If port is zero, ReservePort will search for
+// an unreserved ephemeral port and reserve it, returning its value in the
+// "port" return value.
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // If a port is specified, just try to reserve it for all network
+ // protocols.
+ if port != 0 {
+ if !s.reserveSpecificPort(networks, transport, addr, port, reuse) {
+ return 0, tcpip.ErrPortInUse
+ }
+ return port, nil
+ }
+
+ // A port wasn't specified, so try to find one.
+ return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil
+ })
+}
+
+// reserveSpecificPort tries to reserve the given port on all given protocols.
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) {
+ return false
+ }
+
+ // Reserve port on all network protocols.
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ m, ok := s.allocatedPorts[desc]
+ if !ok {
+ m = make(bindAddresses)
+ s.allocatedPorts[desc] = m
+ }
+ if n, ok := m[addr]; ok {
+ n.refs++
+ m[addr] = n
+ } else {
+ m[addr] = portNode{reuse: reuse, refs: 1}
+ }
+ }
+
+ return true
+}
+
+// ReleasePort releases the reservation on a port/IP combination so that it can
+// be reserved by other endpoints.
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ if m, ok := s.allocatedPorts[desc]; ok {
+ n, ok := m[addr]
+ if !ok {
+ continue
+ }
+ n.refs--
+ if n.refs == 0 {
+ delete(m, addr)
+ } else {
+ m[addr] = n
+ }
+ if len(m) == 0 {
+ delete(s.allocatedPorts, desc)
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/ports/ports_state_autogen.go b/pkg/tcpip/ports/ports_state_autogen.go
new file mode 100755
index 000000000..664cc3e71
--- /dev/null
+++ b/pkg/tcpip/ports/ports_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package ports
+
diff --git a/pkg/tcpip/seqnum/seqnum.go b/pkg/tcpip/seqnum/seqnum.go
new file mode 100644
index 000000000..b40a3c212
--- /dev/null
+++ b/pkg/tcpip/seqnum/seqnum.go
@@ -0,0 +1,67 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package seqnum defines the types and methods for TCP sequence numbers such
+// that they fit in 32-bit words and work properly when overflows occur.
+package seqnum
+
+// Value represents the value of a sequence number.
+type Value uint32
+
+// Size represents the size (length) of a sequence number window.
+type Size uint32
+
+// LessThan checks if v is before w, i.e., v < w.
+func (v Value) LessThan(w Value) bool {
+ return int32(v-w) < 0
+}
+
+// LessThanEq returns true if v==w or v is before i.e., v < w.
+func (v Value) LessThanEq(w Value) bool {
+ if v == w {
+ return true
+ }
+ return v.LessThan(w)
+}
+
+// InRange checks if v is in the range [a,b), i.e., a <= v < b.
+func (v Value) InRange(a, b Value) bool {
+ return v-a < b-a
+}
+
+// InWindow checks if v is in the window that starts at 'first' and spans 'size'
+// sequence numbers.
+func (v Value) InWindow(first Value, size Size) bool {
+ return v.InRange(first, first.Add(size))
+}
+
+// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y).
+func Overlap(a Value, b Size, x Value, y Size) bool {
+ return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
+}
+
+// Add calculates the sequence number following the [v, v+s) window.
+func (v Value) Add(s Size) Value {
+ return v + Value(s)
+}
+
+// Size calculates the size of the window defined by [v, w).
+func (v Value) Size(w Value) Size {
+ return Size(w - v)
+}
+
+// UpdateForward updates v such that it becomes v + s.
+func (v *Value) UpdateForward(s Size) {
+ *v += Value(s)
+}
diff --git a/pkg/tcpip/seqnum/seqnum_state_autogen.go b/pkg/tcpip/seqnum/seqnum_state_autogen.go
new file mode 100755
index 000000000..bf76f6ac4
--- /dev/null
+++ b/pkg/tcpip/seqnum/seqnum_state_autogen.go
@@ -0,0 +1,4 @@
+// automatically generated by stateify.
+
+package seqnum
+
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
new file mode 100644
index 000000000..b952ad20f
--- /dev/null
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -0,0 +1,306 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const linkAddrCacheSize = 512 // max cache entries
+
+// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
+//
+// The entries are stored in a ring buffer, oldest entry replaced first.
+//
+// This struct is safe for concurrent use.
+type linkAddrCache struct {
+ // ageLimit is how long a cache entry is valid for.
+ ageLimit time.Duration
+
+ // resolutionTimeout is the amount of time to wait for a link request to
+ // resolve an address.
+ resolutionTimeout time.Duration
+
+ // resolutionAttempts is the number of times an address is attempted to be
+ // resolved before failing.
+ resolutionAttempts int
+
+ mu sync.Mutex
+ cache map[tcpip.FullAddress]*linkAddrEntry
+ next int // array index of next available entry
+ entries [linkAddrCacheSize]linkAddrEntry
+}
+
+// entryState controls the state of a single entry in the cache.
+type entryState int
+
+const (
+ // incomplete means that there is an outstanding request to resolve the
+ // address. This is the initial state.
+ incomplete entryState = iota
+ // ready means that the address has been resolved and can be used.
+ ready
+ // failed means that address resolution timed out and the address
+ // could not be resolved.
+ failed
+ // expired means that the cache entry has expired and the address must be
+ // resolved again.
+ expired
+)
+
+// String implements Stringer.
+func (s entryState) String() string {
+ switch s {
+ case incomplete:
+ return "incomplete"
+ case ready:
+ return "ready"
+ case failed:
+ return "failed"
+ case expired:
+ return "expired"
+ default:
+ return fmt.Sprintf("unknown(%d)", s)
+ }
+}
+
+// A linkAddrEntry is an entry in the linkAddrCache.
+// This struct is thread-compatible.
+type linkAddrEntry struct {
+ addr tcpip.FullAddress
+ linkAddr tcpip.LinkAddress
+ expiration time.Time
+ s entryState
+
+ // wakers is a set of waiters for address resolution result. Anytime
+ // state transitions out of 'incomplete' these waiters are notified.
+ wakers map[*sleep.Waker]struct{}
+
+ done chan struct{}
+}
+
+func (e *linkAddrEntry) state() entryState {
+ if e.s != expired && time.Now().After(e.expiration) {
+ // Force the transition to ensure waiters are notified.
+ e.changeState(expired)
+ }
+ return e.s
+}
+
+func (e *linkAddrEntry) changeState(ns entryState) {
+ if e.s == ns {
+ return
+ }
+
+ // Validate state transition.
+ switch e.s {
+ case incomplete:
+ // All transitions are valid.
+ case ready, failed:
+ if ns != expired {
+ panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
+ }
+ case expired:
+ // Terminal state.
+ panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
+ default:
+ panic(fmt.Sprintf("invalid state: %s", e.s))
+ }
+
+ // Notify whoever is waiting on address resolution when transitioning
+ // out of 'incomplete'.
+ if e.s == incomplete {
+ for w := range e.wakers {
+ w.Assert()
+ }
+ e.wakers = nil
+ if e.done != nil {
+ close(e.done)
+ }
+ }
+ e.s = ns
+}
+
+func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) {
+ if w != nil {
+ e.wakers[w] = struct{}{}
+ }
+}
+
+func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
+ delete(e.wakers, w)
+}
+
+// add adds a k -> v mapping to the cache.
+func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ entry, ok := c.cache[k]
+ if ok {
+ s := entry.state()
+ if s != expired && entry.linkAddr == v {
+ // Disregard repeated calls.
+ return
+ }
+ // Check if entry is waiting for address resolution.
+ if s == incomplete {
+ entry.linkAddr = v
+ } else {
+ // Otherwise create a new entry to replace it.
+ entry = c.makeAndAddEntry(k, v)
+ }
+ } else {
+ entry = c.makeAndAddEntry(k, v)
+ }
+
+ entry.changeState(ready)
+}
+
+// makeAndAddEntry is a helper function to create and add a new
+// entry to the cache map and evict older entry as needed.
+func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
+ // Take over the next entry.
+ entry := &c.entries[c.next]
+ if c.cache[entry.addr] == entry {
+ delete(c.cache, entry.addr)
+ }
+
+ // Mark the soon-to-be-replaced entry as expired, just in case there is
+ // someone waiting for address resolution on it.
+ entry.changeState(expired)
+
+ *entry = linkAddrEntry{
+ addr: k,
+ linkAddr: v,
+ expiration: time.Now().Add(c.ageLimit),
+ wakers: make(map[*sleep.Waker]struct{}),
+ done: make(chan struct{}),
+ }
+
+ c.cache[k] = entry
+ c.next = (c.next + 1) % len(c.entries)
+ return entry
+}
+
+// get reports any known link address for k.
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+ if linkRes != nil {
+ if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+ return addr, nil, nil
+ }
+ }
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if entry, ok := c.cache[k]; ok {
+ switch s := entry.state(); s {
+ case expired:
+ case ready:
+ return entry.linkAddr, nil, nil
+ case failed:
+ return "", nil, tcpip.ErrNoLinkAddress
+ case incomplete:
+ // Address resolution is still in progress.
+ entry.maybeAddWaker(waker)
+ return "", entry.done, tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
+ }
+
+ if linkRes == nil {
+ return "", nil, tcpip.ErrNoLinkAddress
+ }
+
+ // Add 'incomplete' entry in the cache to mark that resolution is in progress.
+ e := c.makeAndAddEntry(k, "")
+ e.maybeAddWaker(waker)
+
+ go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+
+ return "", e.done, tcpip.ErrWouldBlock
+}
+
+// removeWaker removes a waker previously added through get().
+func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if entry, ok := c.cache[k]; ok {
+ entry.removeWaker(waker)
+ }
+}
+
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) {
+ for i := 0; ; i++ {
+ // Send link request, then wait for the timeout limit and check
+ // whether the request succeeded.
+ linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+
+ select {
+ case <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(k, i); stop {
+ return
+ }
+ case <-done:
+ return
+ }
+ }
+}
+
+// checkLinkRequest checks whether previous attempt to resolve address has succeeded
+// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
+// can stop, false if another request should be sent.
+func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ entry, ok := c.cache[k]
+ if !ok {
+ // Entry was evicted from the cache.
+ return true
+ }
+
+ switch s := entry.state(); s {
+ case ready, failed, expired:
+ // Entry was made ready by resolver or failed. Either way we're done.
+ return true
+ case incomplete:
+ if attempt+1 >= c.resolutionAttempts {
+ // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed)
+ return true
+ }
+ // No response yet, need to send another ARP request.
+ return false
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
+}
+
+func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
+ return &linkAddrCache{
+ ageLimit: ageLimit,
+ resolutionTimeout: resolutionTimeout,
+ resolutionAttempts: resolutionAttempts,
+ cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
new file mode 100644
index 000000000..50d35de88
--- /dev/null
+++ b/pkg/tcpip/stack/nic.go
@@ -0,0 +1,728 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "strings"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/ilist"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+// NIC represents a "network interface card" to which the networking stack is
+// attached.
+type NIC struct {
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+ loopback bool
+
+ demux *transportDemuxer
+
+ mu sync.RWMutex
+ spoofing bool
+ promiscuous bool
+ primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ subnets []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
+
+ stats NICStats
+}
+
+// NICStats includes transmitted and received stats.
+type NICStats struct {
+ Tx DirectionStats
+ Rx DirectionStats
+}
+
+// DirectionStats includes packet and byte counts.
+type DirectionStats struct {
+ Packets *tcpip.StatCounter
+ Bytes *tcpip.StatCounter
+}
+
+// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior.
+type PrimaryEndpointBehavior int
+
+const (
+ // CanBePrimaryEndpoint indicates the endpoint can be used as a primary
+ // endpoint for new connections with no local address. This is the
+ // default when calling NIC.AddAddress.
+ CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
+
+ // FirstPrimaryEndpoint indicates the endpoint should be the first
+ // primary endpoint considered. If there are multiple endpoints with
+ // this behavior, the most recently-added one will be first.
+ FirstPrimaryEndpoint
+
+ // NeverPrimaryEndpoint indicates the endpoint should never be a
+ // primary endpoint.
+ NeverPrimaryEndpoint
+)
+
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
+ return &NIC{
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ loopback: loopback,
+ demux: newTransportDemuxer(stack),
+ primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ mcastJoins: make(map[NetworkEndpointID]int32),
+ stats: NICStats{
+ Tx: DirectionStats{
+ Packets: &tcpip.StatCounter{},
+ Bytes: &tcpip.StatCounter{},
+ },
+ Rx: DirectionStats{
+ Packets: &tcpip.StatCounter{},
+ Bytes: &tcpip.StatCounter{},
+ },
+ },
+ }
+}
+
+// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
+// to start delivering packets.
+func (n *NIC) attachLinkEndpoint() {
+ n.linkEP.Attach(n)
+}
+
+// setPromiscuousMode enables or disables promiscuous mode.
+func (n *NIC) setPromiscuousMode(enable bool) {
+ n.mu.Lock()
+ n.promiscuous = enable
+ n.mu.Unlock()
+}
+
+func (n *NIC) isPromiscuousMode() bool {
+ n.mu.RLock()
+ rv := n.promiscuous
+ n.mu.RUnlock()
+ return rv
+}
+
+// setSpoofing enables or disables address spoofing.
+func (n *NIC) setSpoofing(enable bool) {
+ n.mu.Lock()
+ n.spoofing = enable
+ n.mu.Unlock()
+}
+
+func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var r *referencedNetworkEndpoint
+
+ // Check for a primary endpoint.
+ if list, ok := n.primary[protocol]; ok {
+ for e := list.Front(); e != nil; e = e.Next() {
+ ref := e.(*referencedNetworkEndpoint)
+ if ref.holdsInsertRef && ref.tryIncRef() {
+ r = ref
+ break
+ }
+ }
+
+ }
+
+ if r == nil {
+ return "", tcpip.Subnet{}, tcpip.ErrNoLinkAddress
+ }
+
+ address := r.ep.ID().LocalAddress
+ r.decRef()
+
+ // Find the least-constrained matching subnet for the address, if one
+ // exists, and return it.
+ var subnet tcpip.Subnet
+ for _, s := range n.subnets {
+ if s.Contains(address) && !subnet.Contains(s.ID()) {
+ subnet = s
+ }
+ }
+ return address, subnet, nil
+}
+
+// primaryEndpoint returns the primary endpoint of n for the given network
+// protocol.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ list := n.primary[protocol]
+ if list == nil {
+ return nil
+ }
+
+ for e := list.Front(); e != nil; e = e.Next() {
+ r := e.(*referencedNetworkEndpoint)
+ // TODO(crawshaw): allow broadcast address when SO_BROADCAST is set.
+ switch r.ep.ID().LocalAddress {
+ case header.IPv4Broadcast, header.IPv4Any:
+ continue
+ }
+ if r.tryIncRef() {
+ return r
+ }
+ }
+
+ return nil
+}
+
+// findEndpoint finds the endpoint, if any, with the given address.
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ id := NetworkEndpointID{address}
+
+ n.mu.RLock()
+ ref := n.endpoints[id]
+ if ref != nil && !ref.tryIncRef() {
+ ref = nil
+ }
+ spoofing := n.spoofing
+ n.mu.RUnlock()
+
+ if ref != nil || !spoofing {
+ return ref
+ }
+
+ // Try again with the lock in exclusive mode. If we still can't get the
+ // endpoint, create a new "temporary" endpoint. It will only exist while
+ // there's a route through it.
+ n.mu.Lock()
+ ref = n.endpoints[id]
+ if ref == nil || !ref.tryIncRef() {
+ ref, _ = n.addAddressLocked(protocol, address, peb, true)
+ if ref != nil {
+ ref.holdsInsertRef = false
+ }
+ }
+ n.mu.Unlock()
+ return ref
+}
+
+func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ // Create the new network endpoint.
+ ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP)
+ if err != nil {
+ return nil, err
+ }
+
+ id := *ep.ID()
+ if ref, ok := n.endpoints[id]; ok {
+ if !replace {
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
+ n.removeEndpointLocked(ref)
+ }
+
+ ref := &referencedNetworkEndpoint{
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocol,
+ holdsInsertRef: true,
+ }
+
+ // Set up cache if link address resolution exists for this protocol.
+ if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+ if _, ok := n.stack.linkAddrResolvers[protocol]; ok {
+ ref.linkCache = n.stack
+ }
+ }
+
+ n.endpoints[id] = ref
+
+ l, ok := n.primary[protocol]
+ if !ok {
+ l = &ilist.List{}
+ n.primary[protocol] = l
+ }
+
+ switch peb {
+ case CanBePrimaryEndpoint:
+ l.PushBack(ref)
+ case FirstPrimaryEndpoint:
+ l.PushFront(ref)
+ }
+
+ return ref, nil
+}
+
+// AddAddress adds a new address to n, so that it starts accepting packets
+// targeted at the given address (and network protocol).
+func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ return n.AddAddressWithOptions(protocol, addr, CanBePrimaryEndpoint)
+}
+
+// AddAddressWithOptions is the same as AddAddress, but allows you to specify
+// whether the new endpoint can be primary or not.
+func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
+ // Add the endpoint.
+ n.mu.Lock()
+ _, err := n.addAddressLocked(protocol, addr, peb, false)
+ n.mu.Unlock()
+
+ return err
+}
+
+// Addresses returns the addresses associated with this NIC.
+func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+ addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
+ for nid, ep := range n.endpoints {
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: ep.protocol,
+ Address: nid.LocalAddress,
+ })
+ }
+ return addrs
+}
+
+// AddSubnet adds a new subnet to n, so that it starts accepting packets
+// targeted at the given address and network protocol.
+func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
+ n.mu.Lock()
+ n.subnets = append(n.subnets, subnet)
+ n.mu.Unlock()
+}
+
+// RemoveSubnet removes the given subnet from n.
+func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) {
+ n.mu.Lock()
+
+ // Use the same underlying array.
+ tmp := n.subnets[:0]
+ for _, sub := range n.subnets {
+ if sub != subnet {
+ tmp = append(tmp, sub)
+ }
+ }
+ n.subnets = tmp
+
+ n.mu.Unlock()
+}
+
+// ContainsSubnet reports whether this NIC contains the given subnet.
+func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool {
+ for _, s := range n.Subnets() {
+ if s == subnet {
+ return true
+ }
+ }
+ return false
+}
+
+// Subnets returns the Subnets associated with this NIC.
+func (n *NIC) Subnets() []tcpip.Subnet {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+ sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints))
+ for nid := range n.endpoints {
+ sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
+ if err != nil {
+ // This should never happen as the mask has been carefully crafted to
+ // match the address.
+ panic("Invalid endpoint subnet: " + err.Error())
+ }
+ sns = append(sns, sn)
+ }
+ return append(sns, n.subnets...)
+}
+
+func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
+ id := *r.ep.ID()
+
+ // Nothing to do if the reference has already been replaced with a
+ // different one.
+ if n.endpoints[id] != r {
+ return
+ }
+
+ if r.holdsInsertRef {
+ panic("Reference count dropped to zero before being removed")
+ }
+
+ delete(n.endpoints, id)
+ wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front()
+ if wasInList {
+ n.primary[r.protocol].Remove(r)
+ }
+
+ r.ep.Close()
+}
+
+func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
+ n.mu.Lock()
+ n.removeEndpointLocked(r)
+ n.mu.Unlock()
+}
+
+func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
+ r := n.endpoints[NetworkEndpointID{addr}]
+ if r == nil || !r.holdsInsertRef {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ r.holdsInsertRef = false
+
+ r.decRefLocked()
+
+ return nil
+}
+
+// RemoveAddress removes an address from n.
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ return n.removeAddressLocked(addr)
+}
+
+// joinGroup adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count.
+func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ if joins == 0 {
+ if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins + 1
+ return nil
+}
+
+// leaveGroup decrements the count for the given multicast address, and when it
+// reaches zero removes the endpoint for this address.
+func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ switch joins {
+ case 0:
+ // There are no joins with this address on this NIC.
+ return tcpip.ErrBadLocalAddress
+ case 1:
+ // This is the last one, clean up.
+ if err := n.removeAddressLocked(addr); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins - 1
+ return nil
+}
+
+// DeliverNetworkPacket finds the appropriate network protocol endpoint and
+// hands the packet over for further processing. This function is called when
+// the NIC receives a packet from the physical interface.
+// Note that the ownership of the slice backing vv is retained by the caller.
+// This rule applies only to the slice itself, not to the items of the slice;
+// the ownership of the items is not retained by the caller.
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ n.stats.Rx.Packets.Increment()
+ n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size()))
+
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ n.stack.stats.UnknownProtocolRcvdPackets.Increment()
+ return
+ }
+
+ if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
+ n.stack.stats.IP.PacketsReceived.Increment()
+ }
+
+ if len(vv.First()) < netProto.MinimumPacketSize() {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+
+ src, dst := netProto.ParseAddresses(vv.First())
+
+ // If the packet is destined to the IPv4 Broadcast address, then make a
+ // route to each IPv4 network endpoint and let each endpoint handle the
+ // packet.
+ if dst == header.IPv4Broadcast {
+ // n.endpoints is mutex protected so acquire lock.
+ n.mu.RLock()
+ for _, ref := range n.endpoints {
+ if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
+ r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remote
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+ }
+ }
+ n.mu.RUnlock()
+ return
+ }
+
+ if ref := n.getRef(protocol, dst); ref != nil {
+ r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remote
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+ return
+ }
+
+ // This NIC doesn't care about the packet. Find a NIC that cares about the
+ // packet and forward it to the NIC.
+ //
+ // TODO: Should we be forwarding the packet even if promiscuous?
+ if n.stack.Forwarding() {
+ r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
+ if err != nil {
+ n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ return
+ }
+ defer r.Release()
+
+ r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.RemoteLinkAddress = remote
+
+ // Found a NIC.
+ n := r.ref.nic
+ n.mu.RLock()
+ ref, ok := n.endpoints[NetworkEndpointID{dst}]
+ n.mu.RUnlock()
+ if ok && ref.tryIncRef() {
+ r.RemoteAddress = src
+ // TODO(b/123449044): Update the source NIC as well.
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+ } else {
+ // n doesn't have a destination endpoint.
+ // Send the packet out of n.
+ hdr := buffer.NewPrependableFromView(vv.First())
+ vv.RemoveFirst()
+
+ // TODO(b/128629022): use route.WritePacket.
+ if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ } else {
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size()))
+ }
+ }
+ return
+ }
+
+ n.stack.stats.IP.InvalidAddressesReceived.Increment()
+}
+
+func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
+ id := NetworkEndpointID{dst}
+
+ n.mu.RLock()
+ if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
+ }
+
+ promiscuous := n.promiscuous
+ // Check if the packet is for a subnet this NIC cares about.
+ if !promiscuous {
+ for _, sn := range n.subnets {
+ if sn.Contains(dst) {
+ promiscuous = true
+ break
+ }
+ }
+ }
+ n.mu.RUnlock()
+ if promiscuous {
+ // Try again with the lock in exclusive mode. If we still can't
+ // get the endpoint, create a new "temporary" one. It will only
+ // exist while there's a route through it.
+ n.mu.Lock()
+ if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
+ n.mu.Unlock()
+ return ref
+ }
+ ref, err := n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true)
+ n.mu.Unlock()
+ if err == nil {
+ ref.holdsInsertRef = false
+ return ref
+ }
+ }
+
+ return nil
+}
+
+// DeliverTransportPacket delivers the packets to the appropriate transport
+// protocol endpoint.
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
+ state, ok := n.stack.transportProtocols[protocol]
+ if !ok {
+ n.stack.stats.UnknownProtocolRcvdPackets.Increment()
+ return
+ }
+
+ transProto := state.proto
+
+ // Raw socket packets are delivered based solely on the transport
+ // protocol number. We do not inspect the payload to ensure it's
+ // validly formed.
+ if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) {
+ n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
+ }
+
+ if len(vv.First()) < transProto.MinimumPacketSize() {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+
+ srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ if err != nil {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+
+ id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
+ if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
+ return
+ }
+ if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
+ return
+ }
+
+ // Try to deliver to per-stack default handler.
+ if state.defaultHandler != nil {
+ if state.defaultHandler(r, id, netHeader, vv) {
+ return
+ }
+ }
+
+ // We could not find an appropriate destination for this packet, so
+ // deliver it to the global handler.
+ if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ }
+}
+
+// DeliverTransportControlPacket delivers control packets to the appropriate
+// transport protocol endpoint.
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ state, ok := n.stack.transportProtocols[trans]
+ if !ok {
+ return
+ }
+
+ transProto := state.proto
+
+ // ICMPv4 only guarantees that 8 bytes of the transport protocol will
+ // be present in the payload. We know that the ports are within the
+ // first 8 bytes for all known transport protocols.
+ if len(vv.First()) < 8 {
+ return
+ }
+
+ srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ if err != nil {
+ return
+ }
+
+ id := TransportEndpointID{srcPort, local, dstPort, remote}
+ if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ return
+ }
+ if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ return
+ }
+}
+
+// ID returns the identifier of n.
+func (n *NIC) ID() tcpip.NICID {
+ return n.id
+}
+
+type referencedNetworkEndpoint struct {
+ ilist.Entry
+ refs int32
+ ep NetworkEndpoint
+ nic *NIC
+ protocol tcpip.NetworkProtocolNumber
+
+ // linkCache is set if link address resolution is enabled for this
+ // protocol. Set to nil otherwise.
+ linkCache LinkAddressCache
+
+ // holdsInsertRef is protected by the NIC's mutex. It indicates whether
+ // the reference count is biased by 1 due to the insertion of the
+ // endpoint. It is reset to false when RemoveAddress is called on the
+ // NIC.
+ holdsInsertRef bool
+}
+
+// decRef decrements the ref count and cleans up the endpoint once it reaches
+// zero.
+func (r *referencedNetworkEndpoint) decRef() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpoint(r)
+ }
+}
+
+// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
+// locked.
+func (r *referencedNetworkEndpoint) decRefLocked() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpointLocked(r)
+ }
+}
+
+// incRef increments the ref count. It must only be called when the caller is
+// known to be holding a reference to the endpoint, otherwise tryIncRef should
+// be used.
+func (r *referencedNetworkEndpoint) incRef() {
+ atomic.AddInt32(&r.refs, 1)
+}
+
+// tryIncRef attempts to increment the ref count from n to n+1, but only if n is
+// not zero. That is, it will increment the count if the endpoint is still
+// alive, and do nothing if it has already been clean up.
+func (r *referencedNetworkEndpoint) tryIncRef() bool {
+ for {
+ v := atomic.LoadInt32(&r.refs)
+ if v == 0 {
+ return false
+ }
+
+ if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
+ return true
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
new file mode 100644
index 000000000..c70533a35
--- /dev/null
+++ b/pkg/tcpip/stack/registration.go
@@ -0,0 +1,441 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// NetworkEndpointID is the identifier of a network layer protocol endpoint.
+// Currently the local address is sufficient because all supported protocols
+// (i.e., IPv4 and IPv6) have different sizes for their addresses.
+type NetworkEndpointID struct {
+ LocalAddress tcpip.Address
+}
+
+// TransportEndpointID is the identifier of a transport layer protocol endpoint.
+//
+// +stateify savable
+type TransportEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress it the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// ControlType is the type of network control message.
+type ControlType int
+
+// The following are the allowed values for ControlType values.
+const (
+ ControlPacketTooBig ControlType = iota
+ ControlPortUnreachable
+ ControlUnknown
+)
+
+// TransportEndpoint is the interface that needs to be implemented by transport
+// protocol (e.g., tcp, udp) endpoints that can handle packets.
+type TransportEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive to
+ // this transport endpoint.
+ HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
+
+ // HandleControlPacket is called by the stack when new control (e.g.,
+ // ICMP) packets arrive to this transport endpoint.
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
+}
+
+// RawTransportEndpoint is the interface that needs to be implemented by raw
+// transport protocol endpoints. RawTransportEndpoints receive the entire
+// packet - including the link, network, and transport headers - as delivered
+// to netstack.
+type RawTransportEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive to
+ // this transport endpoint. The packet contains all data from the link
+ // layer up.
+ HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
+}
+
+// TransportProtocol is the interface that needs to be implemented by transport
+// protocols (e.g., tcp, udp) that want to be part of the networking stack.
+type TransportProtocol interface {
+ // Number returns the transport protocol number.
+ Number() tcpip.TransportProtocolNumber
+
+ // NewEndpoint creates a new endpoint of the transport protocol.
+ NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // NewRawEndpoint creates a new raw endpoint of the transport protocol.
+ NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // MinimumPacketSize returns the minimum valid packet size of this
+ // transport protocol. The stack automatically drops any packets smaller
+ // than this targeted at this protocol.
+ MinimumPacketSize() int
+
+ // ParsePorts returns the source and destination ports stored in a
+ // packet of this protocol.
+ ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
+
+ // HandleUnknownDestinationPacket handles packets targeted at this
+ // protocol but that don't match any existing endpoint. For example,
+ // it is targeted at a port that have no listeners.
+ //
+ // The return value indicates whether the packet was well-formed (for
+ // stats purposes only).
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool
+
+ // SetOption allows enabling/disabling protocol specific features.
+ // SetOption returns an error if the option is not supported or the
+ // provided option value is invalid.
+ SetOption(option interface{}) *tcpip.Error
+
+ // Option allows retrieving protocol specific option values.
+ // Option returns an error if the option is not supported or the
+ // provided option value is invalid.
+ Option(option interface{}) *tcpip.Error
+}
+
+// TransportDispatcher contains the methods used by the network stack to deliver
+// packets to the appropriate transport endpoint after it has been handled by
+// the network layer.
+type TransportDispatcher interface {
+ // DeliverTransportPacket delivers packets to the appropriate
+ // transport protocol endpoint. It also returns the network layer
+ // header for the enpoint to inspect or pass up the stack.
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView)
+
+ // DeliverTransportControlPacket delivers control packets to the
+ // appropriate transport protocol endpoint.
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
+}
+
+// PacketLooping specifies where an outbound packet should be sent.
+type PacketLooping byte
+
+const (
+ // PacketOut indicates that the packet should be passed to the link
+ // endpoint.
+ PacketOut PacketLooping = 1 << iota
+
+ // PacketLoop indicates that the packet should be handled locally.
+ PacketLoop
+)
+
+// NetworkEndpoint is the interface that needs to be implemented by endpoints
+// of network layer protocols (e.g., ipv4, ipv6).
+type NetworkEndpoint interface {
+ // DefaultTTL is the default time-to-live value (or hop limit, in ipv6)
+ // for this endpoint.
+ DefaultTTL() uint8
+
+ // MTU is the maximum transmission unit for this endpoint. This is
+ // generally calculated as the MTU of the underlying data link endpoint
+ // minus the network endpoint max header length.
+ MTU() uint32
+
+ // Capabilities returns the set of capabilities supported by the
+ // underlying link-layer endpoint.
+ Capabilities() LinkEndpointCapabilities
+
+ // MaxHeaderLength returns the maximum size the network (and lower
+ // level layers combined) headers can have. Higher levels use this
+ // information to reserve space in the front of the packets they're
+ // building.
+ MaxHeaderLength() uint16
+
+ // WritePacket writes a packet to the given destination address and
+ // protocol.
+ WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error
+
+ // ID returns the network protocol endpoint ID.
+ ID() *NetworkEndpointID
+
+ // NICID returns the id of the NIC this endpoint belongs to.
+ NICID() tcpip.NICID
+
+ // HandlePacket is called by the link layer when new packets arrive to
+ // this network endpoint.
+ HandlePacket(r *Route, vv buffer.VectorisedView)
+
+ // Close is called when the endpoint is reomved from a stack.
+ Close()
+}
+
+// NetworkProtocol is the interface that needs to be implemented by network
+// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack.
+type NetworkProtocol interface {
+ // Number returns the network protocol number.
+ Number() tcpip.NetworkProtocolNumber
+
+ // MinimumPacketSize returns the minimum valid packet size of this
+ // network protocol. The stack automatically drops any packets smaller
+ // than this targeted at this protocol.
+ MinimumPacketSize() int
+
+ // ParsePorts returns the source and destination addresses stored in a
+ // packet of this protocol.
+ ParseAddresses(v buffer.View) (src, dst tcpip.Address)
+
+ // NewEndpoint creates a new endpoint of this protocol.
+ NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
+
+ // SetOption allows enabling/disabling protocol specific features.
+ // SetOption returns an error if the option is not supported or the
+ // provided option value is invalid.
+ SetOption(option interface{}) *tcpip.Error
+
+ // Option allows retrieving protocol specific option values.
+ // Option returns an error if the option is not supported or the
+ // provided option value is invalid.
+ Option(option interface{}) *tcpip.Error
+}
+
+// NetworkDispatcher contains the methods used by the network stack to deliver
+// packets to the appropriate network endpoint after it has been handled by
+// the data link layer.
+type NetworkDispatcher interface {
+ // DeliverNetworkPacket finds the appropriate network protocol
+ // endpoint and hands the packet over for further processing.
+ DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+}
+
+// LinkEndpointCapabilities is the type associated with the capabilities
+// supported by a link-layer endpoint. It is a set of bitfields.
+type LinkEndpointCapabilities uint
+
+// The following are the supported link endpoint capabilities.
+const (
+ CapabilityNone LinkEndpointCapabilities = 0
+ // CapabilityTXChecksumOffload indicates that the link endpoint supports
+ // checksum computation for outgoing packets and the stack can skip
+ // computing checksums when sending packets.
+ CapabilityTXChecksumOffload LinkEndpointCapabilities = 1 << iota
+ // CapabilityRXChecksumOffload indicates that the link endpoint supports
+ // checksum verification on received packets and that it's safe for the
+ // stack to skip checksum verification.
+ CapabilityRXChecksumOffload
+ CapabilityResolutionRequired
+ CapabilitySaveRestore
+ CapabilityDisconnectOk
+ CapabilityLoopback
+ CapabilityGSO
+)
+
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint.
+type LinkEndpoint interface {
+ // MTU is the maximum transmission unit for this endpoint. This is
+ // usually dictated by the backing physical network; when such a
+ // physical network doesn't exist, the limit is generally 64k, which
+ // includes the maximum size of an IP packet.
+ MTU() uint32
+
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
+
+ // MaxHeaderLength returns the maximum size the data link (and
+ // lower level layers combined) headers can have. Higher levels use this
+ // information to reserve space in the front of the packets they're
+ // building.
+ MaxHeaderLength() uint16
+
+ // LinkAddress returns the link address (typically a MAC) of the
+ // link endpoint.
+ LinkAddress() tcpip.LinkAddress
+
+ // WritePacket writes a packet with the given protocol through the given
+ // route.
+ //
+ // To participate in transparent bridging, a LinkEndpoint implementation
+ // should call eth.Encode with header.EthernetFields.SrcAddr set to
+ // r.LocalLinkAddress if it is provided.
+ WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+
+ // Attach attaches the data link layer endpoint to the network-layer
+ // dispatcher of the stack.
+ Attach(dispatcher NetworkDispatcher)
+
+ // IsAttached returns whether a NetworkDispatcher is attached to the
+ // endpoint.
+ IsAttached() bool
+}
+
+// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
+// delivered via the Inject method.
+type InjectableLinkEndpoint interface {
+ LinkEndpoint
+
+ // Inject injects an inbound packet.
+ Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+
+ // WriteRawPacket writes a fully formed outbound packet directly to the link.
+ //
+ // dest is used by endpoints with multiple raw destinations.
+ WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error
+}
+
+// A LinkAddressResolver is an extension to a NetworkProtocol that
+// can resolve link addresses.
+type LinkAddressResolver interface {
+ // LinkAddressRequest sends a request for the LinkAddress of addr.
+ // The request is sent on linkEP with localAddr as the source.
+ //
+ // A valid response will cause the discovery protocol's network
+ // endpoint to call AddLinkAddress.
+ LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+
+ // ResolveStaticAddress attempts to resolve address without sending
+ // requests. It either resolves the name immediately or returns the
+ // empty LinkAddress.
+ //
+ // It can be used to resolve broadcast addresses for example.
+ ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool)
+
+ // LinkAddressProtocol returns the network protocol of the
+ // addresses this this resolver can resolve.
+ LinkAddressProtocol() tcpip.NetworkProtocolNumber
+}
+
+// A LinkAddressCache caches link addresses.
+type LinkAddressCache interface {
+ // CheckLocalAddress determines if the given local address exists, and if it
+ // does not exist.
+ CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
+
+ // AddLinkAddress adds a link address to the cache.
+ AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+
+ // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
+ // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
+ // registered with the network protocol, the cache attempts to resolve the address
+ // and returns ErrWouldBlock. Waker is notified when address resolution is
+ // complete (success or not).
+ //
+ // If address resolution is required, ErrNoLinkAddress and a notification channel is
+ // returned for the top level caller to block. Channel is closed once address resolution
+ // is complete (success or not).
+ GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
+
+ // RemoveWaker removes a waker that has been added in GetLinkAddress().
+ RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+}
+
+// TransportProtocolFactory functions are used by the stack to instantiate
+// transport protocols.
+type TransportProtocolFactory func() TransportProtocol
+
+// NetworkProtocolFactory provides methods to be used by the stack to
+// instantiate network protocols.
+type NetworkProtocolFactory func() NetworkProtocol
+
+var (
+ transportProtocols = make(map[string]TransportProtocolFactory)
+ networkProtocols = make(map[string]NetworkProtocolFactory)
+
+ linkEPMu sync.RWMutex
+ nextLinkEndpointID tcpip.LinkEndpointID = 1
+ linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint)
+)
+
+// RegisterTransportProtocolFactory registers a new transport protocol factory
+// with the stack so that it becomes available to users of the stack. This
+// function is intended to be called by init() functions of the protocols.
+func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
+ transportProtocols[name] = p
+}
+
+// RegisterNetworkProtocolFactory registers a new network protocol factory with
+// the stack so that it becomes available to users of the stack. This function
+// is intended to be called by init() functions of the protocols.
+func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
+ networkProtocols[name] = p
+}
+
+// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an
+// ID that can be used to refer to it.
+func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
+ linkEPMu.Lock()
+ defer linkEPMu.Unlock()
+
+ v := nextLinkEndpointID
+ nextLinkEndpointID++
+
+ linkEndpoints[v] = linkEP
+
+ return v
+}
+
+// FindLinkEndpoint finds the link endpoint associated with the given ID.
+func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint {
+ linkEPMu.RLock()
+ defer linkEPMu.RUnlock()
+
+ return linkEndpoints[id]
+}
+
+// GSOType is the type of GSO segments.
+//
+// +stateify savable
+type GSOType int
+
+// Types of gso segments.
+const (
+ GSONone GSOType = iota
+ GSOTCPv4
+ GSOTCPv6
+)
+
+// GSO contains generic segmentation offload properties.
+//
+// +stateify savable
+type GSO struct {
+ // Type is one of GSONone, GSOTCPv4, etc.
+ Type GSOType
+ // NeedsCsum is set if the checksum offload is enabled.
+ NeedsCsum bool
+ // CsumOffset is offset after that to place checksum.
+ CsumOffset uint16
+
+ // Mss is maximum segment size.
+ MSS uint16
+ // L3Len is L3 (IP) header length.
+ L3HdrLen uint16
+
+ // MaxSize is maximum GSO packet size.
+ MaxSize uint32
+}
+
+// GSOEndpoint provides access to GSO properties.
+type GSOEndpoint interface {
+ // GSOMaxSize returns the maximum GSO packet size.
+ GSOMaxSize() uint32
+}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
new file mode 100644
index 000000000..3d4c282a9
--- /dev/null
+++ b/pkg/tcpip/stack/route.go
@@ -0,0 +1,189 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+// Route represents a route through the networking stack to a given destination.
+type Route struct {
+ // RemoteAddress is the final destination of the route.
+ RemoteAddress tcpip.Address
+
+ // RemoteLinkAddress is the link-layer (MAC) address of the
+ // final destination of the route.
+ RemoteLinkAddress tcpip.LinkAddress
+
+ // LocalAddress is the local address where the route starts.
+ LocalAddress tcpip.Address
+
+ // LocalLinkAddress is the link-layer (MAC) address of the
+ // where the route starts.
+ LocalLinkAddress tcpip.LinkAddress
+
+ // NextHop is the next node in the path to the destination.
+ NextHop tcpip.Address
+
+ // NetProto is the network-layer protocol.
+ NetProto tcpip.NetworkProtocolNumber
+
+ // ref a reference to the network endpoint through which the route
+ // starts.
+ ref *referencedNetworkEndpoint
+
+ // loop controls where WritePacket should send packets.
+ loop PacketLooping
+}
+
+// makeRoute initializes a new route. It takes ownership of the provided
+// reference to a network endpoint.
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, handleLocal, multicastLoop bool) Route {
+ loop := PacketOut
+ if handleLocal && localAddr != "" && remoteAddr == localAddr {
+ loop = PacketLoop
+ } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
+ loop |= PacketLoop
+ }
+
+ return Route{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: localLinkAddr,
+ RemoteAddress: remoteAddr,
+ ref: ref,
+ loop: loop,
+ }
+}
+
+// NICID returns the id of the NIC from which this route originates.
+func (r *Route) NICID() tcpip.NICID {
+ return r.ref.ep.NICID()
+}
+
+// MaxHeaderLength forwards the call to the network endpoint's implementation.
+func (r *Route) MaxHeaderLength() uint16 {
+ return r.ref.ep.MaxHeaderLength()
+}
+
+// Stats returns a mutable copy of current stats.
+func (r *Route) Stats() tcpip.Stats {
+ return r.ref.nic.stack.Stats()
+}
+
+// PseudoHeaderChecksum forwards the call to the network endpoint's
+// implementation.
+func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, totalLen uint16) uint16 {
+ return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen)
+}
+
+// Capabilities returns the link-layer capabilities of the route.
+func (r *Route) Capabilities() LinkEndpointCapabilities {
+ return r.ref.ep.Capabilities()
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (r *Route) GSOMaxSize() uint32 {
+ if gso, ok := r.ref.ep.(GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
+// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
+// notified when address resolution is complete (success or not).
+//
+// If address resolution is required, ErrNoLinkAddress and a notification channel is
+// returned for the top level caller to block. Channel is closed once address resolution
+// is complete (success or not).
+func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
+ if !r.IsResolutionRequired() {
+ // Nothing to do if there is no cache (which does the resolution on cache miss) or
+ // link address is already known.
+ return nil, nil
+ }
+
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ // Local link address is already known.
+ if r.RemoteAddress == r.LocalAddress {
+ r.RemoteLinkAddress = r.LocalLinkAddress
+ return nil, nil
+ }
+ nextAddr = r.RemoteAddress
+ }
+ linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ if err != nil {
+ return ch, err
+ }
+ r.RemoteLinkAddress = linkAddr
+ return nil, nil
+}
+
+// RemoveWaker removes a waker that has been added in Resolve().
+func (r *Route) RemoveWaker(waker *sleep.Waker) {
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ nextAddr = r.RemoteAddress
+ }
+ r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
+}
+
+// IsResolutionRequired returns true if Resolve() must be called to resolve
+// the link address before the this route can be written to.
+func (r *Route) IsResolutionRequired() bool {
+ return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+}
+
+// WritePacket writes the packet through the given route.
+func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+ err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop)
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ } else {
+ r.ref.nic.stats.Tx.Packets.Increment()
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + payload.Size()))
+ }
+ return err
+}
+
+// DefaultTTL returns the default TTL of the underlying network endpoint.
+func (r *Route) DefaultTTL() uint8 {
+ return r.ref.ep.DefaultTTL()
+}
+
+// MTU returns the MTU of the underlying network endpoint.
+func (r *Route) MTU() uint32 {
+ return r.ref.ep.MTU()
+}
+
+// Release frees all resources associated with the route.
+func (r *Route) Release() {
+ if r.ref != nil {
+ r.ref.decRef()
+ r.ref = nil
+ }
+}
+
+// Clone Clone a route such that the original one can be released and the new
+// one will remain valid.
+func (r *Route) Clone() Route {
+ r.ref.incRef()
+ return *r
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
new file mode 100644
index 000000000..9d8e8cda5
--- /dev/null
+++ b/pkg/tcpip/stack/stack.go
@@ -0,0 +1,1095 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package stack provides the glue between networking protocols and the
+// consumers of the networking stack.
+//
+// For consumers, the only function of interest is New(), everything else is
+// provided by the tcpip/public package.
+//
+// For protocol implementers, RegisterTransportProtocolFactory() and
+// RegisterNetworkProtocolFactory() are used to register protocol factories with
+// the stack, which will then be used to instantiate protocol objects when
+// consumers interact with the stack.
+package stack
+
+import (
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/ports"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ageLimit is set to the same cache stale time used in Linux.
+ ageLimit = 1 * time.Minute
+ // resolutionTimeout is set to the same ARP timeout used in Linux.
+ resolutionTimeout = 1 * time.Second
+ // resolutionAttempts is set to the same ARP retries used in Linux.
+ resolutionAttempts = 3
+)
+
+type transportProtocolState struct {
+ proto TransportProtocol
+ defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
+}
+
+// TCPProbeFunc is the expected function type for a TCP probe function to be
+// passed to stack.AddTCPProbe.
+type TCPProbeFunc func(s TCPEndpointState)
+
+// TCPCubicState is used to hold a copy of the internal cubic state when the
+// TCPProbeFunc is invoked.
+type TCPCubicState struct {
+ WLastMax float64
+ WMax float64
+ T time.Time
+ TimeSinceLastCongestion time.Duration
+ C float64
+ K float64
+ Beta float64
+ WC float64
+ WEst float64
+}
+
+// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
+type TCPEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress it the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
+// TCP endpoint.
+type TCPFastRecoveryState struct {
+ // Active if true indicates the endpoint is in fast recovery.
+ Active bool
+
+ // First is the first unacknowledged sequence number being recovered.
+ First seqnum.Value
+
+ // Last is the 'recover' sequence number that indicates the point at
+ // which we should exit recovery barring any timeouts etc.
+ Last seqnum.Value
+
+ // MaxCwnd is the maximum value we are permitted to grow the congestion
+ // window during recovery. This is set at the time we enter recovery.
+ MaxCwnd int
+
+ // HighRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase.
+ // See: RFC 6675 Section 2 for details.
+ HighRxt seqnum.Value
+
+ // RescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission.
+ // See: RFC 6675 Section 2 for details.
+ RescueRxt seqnum.Value
+}
+
+// TCPReceiverState holds a copy of the internal state of the receiver for
+// a given TCP endpoint.
+type TCPReceiverState struct {
+ // RcvNxt is the TCP variable RCV.NXT.
+ RcvNxt seqnum.Value
+
+ // RcvAcc is the TCP variable RCV.ACC.
+ RcvAcc seqnum.Value
+
+ // RcvWndScale is the window scaling to use for inbound segments.
+ RcvWndScale uint8
+
+ // PendingBufUsed is the number of bytes pending in the receive
+ // queue.
+ PendingBufUsed seqnum.Size
+
+ // PendingBufSize is the size of the socket receive buffer.
+ PendingBufSize seqnum.Size
+}
+
+// TCPSenderState holds a copy of the internal state of the sender for
+// a given TCP Endpoint.
+type TCPSenderState struct {
+ // LastSendTime is the time at which we sent the last segment.
+ LastSendTime time.Time
+
+ // DupAckCount is the number of Duplicate ACK's received.
+ DupAckCount int
+
+ // SndCwnd is the size of the sending congestion window in packets.
+ SndCwnd int
+
+ // Ssthresh is the slow start threshold in packets.
+ Ssthresh int
+
+ // SndCAAckCount is the number of packets consumed in congestion
+ // avoidance mode.
+ SndCAAckCount int
+
+ // Outstanding is the number of packets in flight.
+ Outstanding int
+
+ // SndWnd is the send window size in bytes.
+ SndWnd seqnum.Size
+
+ // SndUna is the next unacknowledged sequence number.
+ SndUna seqnum.Value
+
+ // SndNxt is the sequence number of the next segment to be sent.
+ SndNxt seqnum.Value
+
+ // RTTMeasureSeqNum is the sequence number being used for the latest RTT
+ // measurement.
+ RTTMeasureSeqNum seqnum.Value
+
+ // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
+ RTTMeasureTime time.Time
+
+ // Closed indicates that the caller has closed the endpoint for sending.
+ Closed bool
+
+ // SRTT is the smoothed round-trip time as defined in section 2 of
+ // RFC 6298.
+ SRTT time.Duration
+
+ // RTO is the retransmit timeout as defined in section of 2 of RFC 6298.
+ RTO time.Duration
+
+ // RTTVar is the round-trip time variation as defined in section 2 of
+ // RFC 6298.
+ RTTVar time.Duration
+
+ // SRTTInited if true indicates take a valid RTT measurement has been
+ // completed.
+ SRTTInited bool
+
+ // MaxPayloadSize is the maximum size of the payload of a given segment.
+ // It is initialized on demand.
+ MaxPayloadSize int
+
+ // SndWndScale is the number of bits to shift left when reading the send
+ // window size from a segment.
+ SndWndScale uint8
+
+ // MaxSentAck is the highest acknowledgement number sent till now.
+ MaxSentAck seqnum.Value
+
+ // FastRecovery holds the fast recovery state for the endpoint.
+ FastRecovery TCPFastRecoveryState
+
+ // Cubic holds the state related to CUBIC congestion control.
+ Cubic TCPCubicState
+}
+
+// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
+type TCPSACKInfo struct {
+ // Blocks is the list of SACK Blocks that identify the out of order segments
+ // held by a given TCP endpoint.
+ Blocks []header.SACKBlock
+
+ // ReceivedBlocks are the SACK blocks received by this endpoint
+ // from the peer endpoint.
+ ReceivedBlocks []header.SACKBlock
+
+ // MaxSACKED is the highest sequence number that has been SACKED
+ // by the peer.
+ MaxSACKED seqnum.Value
+}
+
+// TCPEndpointState is a copy of the internal state of a TCP endpoint.
+type TCPEndpointState struct {
+ // ID is a copy of the TransportEndpointID for the endpoint.
+ ID TCPEndpointID
+
+ // SegTime denotes the absolute time when this segment was received.
+ SegTime time.Time
+
+ // RcvBufSize is the size of the receive socket buffer for the endpoint.
+ RcvBufSize int
+
+ // RcvBufUsed is the amount of bytes actually held in the receive socket
+ // buffer for the endpoint.
+ RcvBufUsed int
+
+ // RcvClosed if true, indicates the endpoint has been closed for reading.
+ RcvClosed bool
+
+ // SendTSOk is used to indicate when the TS Option has been negotiated.
+ // When sendTSOk is true every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1.
+ SendTSOk bool
+
+ // RecentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field is
+ // updated if required when a new segment is received by this endpoint.
+ RecentTS uint32
+
+ // TSOffset is a randomized offset added to the value of the TSVal field
+ // in the timestamp option.
+ TSOffset uint32
+
+ // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ SACKPermitted bool
+
+ // SACK holds TCP SACK related information for this endpoint.
+ SACK TCPSACKInfo
+
+ // SndBufSize is the size of the socket send buffer.
+ SndBufSize int
+
+ // SndBufUsed is the number of bytes held in the socket send buffer.
+ SndBufUsed int
+
+ // SndClosed indicates that the endpoint has been closed for sends.
+ SndClosed bool
+
+ // SndBufInQueue is the number of bytes in the send queue.
+ SndBufInQueue seqnum.Size
+
+ // PacketTooBigCount is used to notify the main protocol routine how
+ // many times a "packet too big" control packet is received.
+ PacketTooBigCount int
+
+ // SndMTU is the smallest MTU seen in the control packets received.
+ SndMTU int
+
+ // Receiver holds variables related to the TCP receiver for the endpoint.
+ Receiver TCPReceiverState
+
+ // Sender holds state related to the TCP Sender for the endpoint.
+ Sender TCPSenderState
+}
+
+// Stack is a networking stack, with all supported protocols, NICs, and route
+// table.
+type Stack struct {
+ transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState
+ networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
+ linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+
+ demux *transportDemuxer
+
+ stats tcpip.Stats
+
+ linkAddrCache *linkAddrCache
+
+ // raw indicates whether raw sockets may be created. It is set during
+ // Stack creation and is immutable.
+ raw bool
+
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
+ forwarding bool
+
+ // route is the route table passed in by the user via SetRouteTable(),
+ // it is used by FindRoute() to build a route for a specific
+ // destination.
+ routeTable []tcpip.Route
+
+ *ports.PortManager
+
+ // If not nil, then any new endpoints will have this probe function
+ // invoked everytime they receive a TCP segment.
+ tcpProbeFunc TCPProbeFunc
+
+ // clock is used to generate user-visible times.
+ clock tcpip.Clock
+
+ // handleLocal allows non-loopback interfaces to loop packets.
+ handleLocal bool
+}
+
+// Options contains optional Stack configuration.
+type Options struct {
+ // Clock is an optional clock source used for timestampping packets.
+ //
+ // If no Clock is specified, the clock source will be time.Now.
+ Clock tcpip.Clock
+
+ // Stats are optional statistic counters.
+ Stats tcpip.Stats
+
+ // HandleLocal indicates whether packets destined to their source
+ // should be handled by the stack internally (true) or outside the
+ // stack (false).
+ HandleLocal bool
+
+ // Raw indicates whether raw sockets may be created.
+ Raw bool
+}
+
+// New allocates a new networking stack with only the requested networking and
+// transport protocols configured with default options.
+//
+// Protocol options can be changed by calling the
+// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
+// stack. Please refer to individual protocol implementations as to what options
+// are supported.
+func New(network []string, transport []string, opts Options) *Stack {
+ clock := opts.Clock
+ if clock == nil {
+ clock = &tcpip.StdClock{}
+ }
+
+ s := &Stack{
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ raw: opts.Raw,
+ }
+
+ // Add specified network protocols.
+ for _, name := range network {
+ netProtoFactory, ok := networkProtocols[name]
+ if !ok {
+ continue
+ }
+ netProto := netProtoFactory()
+ s.networkProtocols[netProto.Number()] = netProto
+ if r, ok := netProto.(LinkAddressResolver); ok {
+ s.linkAddrResolvers[r.LinkAddressProtocol()] = r
+ }
+ }
+
+ // Add specified transport protocols.
+ for _, name := range transport {
+ transProtoFactory, ok := transportProtocols[name]
+ if !ok {
+ continue
+ }
+ transProto := transProtoFactory()
+ s.transportProtocols[transProto.Number()] = &transportProtocolState{
+ proto: transProto,
+ }
+ }
+
+ // Create the global transport demuxer.
+ s.demux = newTransportDemuxer(s)
+
+ return s
+}
+
+// SetNetworkProtocolOption allows configuring individual protocol level
+// options. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation or the provided value
+// is incorrect.
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+ netProto, ok := s.networkProtocols[network]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return netProto.SetOption(option)
+}
+
+// NetworkProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation.
+// e.g.
+// var v ipv4.MyOption
+// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v)
+// if err != nil {
+// ...
+// }
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+ netProto, ok := s.networkProtocols[network]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return netProto.Option(option)
+}
+
+// SetTransportProtocolOption allows configuring individual protocol level
+// options. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation or the provided value
+// is incorrect.
+func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+ transProtoState, ok := s.transportProtocols[transport]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return transProtoState.proto.SetOption(option)
+}
+
+// TransportProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation.
+// var v tcp.SACKEnabled
+// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
+// ...
+// }
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+ transProtoState, ok := s.transportProtocols[transport]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return transProtoState.proto.Option(option)
+}
+
+// SetTransportProtocolHandler sets the per-stack default handler for the given
+// protocol.
+//
+// It must be called only during initialization of the stack. Changing it as the
+// stack is operating is not supported.
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) {
+ state := s.transportProtocols[p]
+ if state != nil {
+ state.defaultHandler = h
+ }
+}
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (s *Stack) NowNanoseconds() int64 {
+ return s.clock.NowNanoseconds()
+}
+
+// Stats returns a mutable copy of the current stats.
+//
+// This is not generally exported via the public interface, but is available
+// internally.
+func (s *Stack) Stats() tcpip.Stats {
+ return s.stats
+}
+
+// SetForwarding enables or disables the packet forwarding between NICs.
+func (s *Stack) SetForwarding(enable bool) {
+ // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
+ s.mu.Lock()
+ s.forwarding = enable
+ s.mu.Unlock()
+}
+
+// Forwarding returns if the packet forwarding between NICs is enabled.
+func (s *Stack) Forwarding() bool {
+ // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.forwarding
+}
+
+// SetRouteTable assigns the route table to be used by this stack. It
+// specifies which NIC to use for given destination address ranges.
+func (s *Stack) SetRouteTable(table []tcpip.Route) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.routeTable = table
+}
+
+// GetRouteTable returns the route table which is currently in use.
+func (s *Stack) GetRouteTable() []tcpip.Route {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return append([]tcpip.Route(nil), s.routeTable...)
+}
+
+// NewEndpoint creates a new transport layer endpoint of the given protocol.
+func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ t, ok := s.transportProtocols[transport]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ return t.proto.NewEndpoint(s, network, waiterQueue)
+}
+
+// NewRawEndpoint creates a new raw transport layer endpoint of the given
+// protocol. Raw endpoints receive all traffic for a given protocol regardless
+// of address.
+func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if !s.raw {
+ return nil, tcpip.ErrNotPermitted
+ }
+
+ t, ok := s.transportProtocols[transport]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ return t.proto.NewRawEndpoint(s, network, waiterQueue)
+}
+
+// createNIC creates a NIC with the provided id and link-layer endpoint, and
+// optionally enable it.
+func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error {
+ ep := FindLinkEndpoint(linkEP)
+ if ep == nil {
+ return tcpip.ErrBadLinkEndpoint
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Make sure id is unique.
+ if _, ok := s.nics[id]; ok {
+ return tcpip.ErrDuplicateNICID
+ }
+
+ n := newNIC(s, id, name, ep, loopback)
+
+ s.nics[id] = n
+ if enabled {
+ n.attachLinkEndpoint()
+ }
+
+ return nil
+}
+
+// CreateNIC creates a NIC with the provided id and link-layer endpoint.
+func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, "", linkEP, true, false)
+}
+
+// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
+// and a human-readable name.
+func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, true, false)
+}
+
+// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
+// endpoint, and a human-readable name.
+func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, true, true)
+}
+
+// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
+// but leave it disable. Stack.EnableNIC must be called before the link-layer
+// endpoint starts delivering packets to it.
+func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, "", linkEP, false, false)
+}
+
+// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
+// CreateDisabledNIC.
+func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, false, false)
+}
+
+// EnableNIC enables the given NIC so that the link-layer endpoint can start
+// delivering packets to it.
+func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[id]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.attachLinkEndpoint()
+
+ return nil
+}
+
+// CheckNIC checks if a NIC is usable.
+func (s *Stack) CheckNIC(id tcpip.NICID) bool {
+ s.mu.RLock()
+ nic, ok := s.nics[id]
+ s.mu.RUnlock()
+ if ok {
+ return nic.linkEP.IsAttached()
+ }
+ return false
+}
+
+// NICSubnets returns a map of NICIDs to their associated subnets.
+func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := map[tcpip.NICID][]tcpip.Subnet{}
+
+ for id, nic := range s.nics {
+ nics[id] = append(nics[id], nic.Subnets()...)
+ }
+ return nics
+}
+
+// NICInfo captures the name and addresses assigned to a NIC.
+type NICInfo struct {
+ Name string
+ LinkAddress tcpip.LinkAddress
+ ProtocolAddresses []tcpip.ProtocolAddress
+
+ // Flags indicate the state of the NIC.
+ Flags NICStateFlags
+
+ // MTU is the maximum transmission unit.
+ MTU uint32
+
+ Stats NICStats
+}
+
+// NICInfo returns a map of NICIDs to their associated information.
+func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := make(map[tcpip.NICID]NICInfo)
+ for id, nic := range s.nics {
+ flags := NICStateFlags{
+ Up: true, // Netstack interfaces are always up.
+ Running: nic.linkEP.IsAttached(),
+ Promiscuous: nic.isPromiscuousMode(),
+ Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0,
+ }
+ nics[id] = NICInfo{
+ Name: nic.name,
+ LinkAddress: nic.linkEP.LinkAddress(),
+ ProtocolAddresses: nic.Addresses(),
+ Flags: flags,
+ MTU: nic.linkEP.MTU(),
+ Stats: nic.stats,
+ }
+ }
+ return nics
+}
+
+// NICStateFlags holds information about the state of an NIC.
+type NICStateFlags struct {
+ // Up indicates whether the interface is running.
+ Up bool
+
+ // Running indicates whether resources are allocated.
+ Running bool
+
+ // Promiscuous indicates whether the interface is in promiscuous mode.
+ Promiscuous bool
+
+ // Loopback indicates whether the interface is a loopback.
+ Loopback bool
+}
+
+// AddAddress adds a new network-layer address to the specified NIC.
+func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
+}
+
+// AddAddressWithOptions is the same as AddAddress, but allows you to specify
+// whether the new endpoint can be primary or not.
+func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[id]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.AddAddressWithOptions(protocol, addr, peb)
+}
+
+// AddSubnet adds a subnet range to the specified NIC.
+func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ nic.AddSubnet(protocol, subnet)
+ return nil
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// RemoveSubnet removes the subnet range from the specified NIC.
+func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ nic.RemoveSubnet(subnet)
+ return nil
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// ContainsSubnet reports whether the specified NIC contains the specified
+// subnet.
+func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ return nic.ContainsSubnet(subnet), nil
+ }
+
+ return false, tcpip.ErrUnknownNICID
+}
+
+// RemoveAddress removes an existing network-layer address from the specified
+// NIC.
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ return nic.RemoveAddress(addr)
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// GetMainNICAddress returns the first primary address (and the subnet that
+// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
+// address if no primary addresses exist. Returns an error if the NIC doesn't
+// exist or has no endpoints.
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ return nic.getMainNICAddress(protocol)
+ }
+
+ return "", tcpip.Subnet{}, tcpip.ErrUnknownNICID
+}
+
+func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
+ if len(localAddr) == 0 {
+ return nic.primaryEndpoint(netProto)
+ }
+ return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
+}
+
+// FindRoute creates a route to the given destination address, leaving through
+// the given nic and local address (if provided).
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ isBroadcast := remoteAddr == header.IPv4Broadcast
+ isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
+ needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ if id != 0 && !needRoute {
+ if nic, ok := s.nics[id]; ok {
+ if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
+ return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil
+ }
+ }
+ } else {
+ for _, route := range s.routeTable {
+ if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Match(remoteAddr)) {
+ continue
+ }
+ if nic, ok := s.nics[route.NIC]; ok {
+ if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
+ if len(remoteAddr) == 0 {
+ // If no remote address was provided, then the route
+ // provided will refer to the link local address.
+ remoteAddr = ref.ep.ID().LocalAddress
+ }
+
+ r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback)
+ if needRoute {
+ r.NextHop = route.Gateway
+ }
+ return r, nil
+ }
+ }
+ }
+ }
+
+ if !needRoute {
+ return Route{}, tcpip.ErrNetworkUnreachable
+ }
+
+ return Route{}, tcpip.ErrNoRoute
+}
+
+// CheckNetworkProtocol checks if a given network protocol is enabled in the
+// stack.
+func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool {
+ _, ok := s.networkProtocols[protocol]
+ return ok
+}
+
+// CheckLocalAddress determines if the given local address exists, and if it
+// does, returns the id of the NIC it's bound to. Returns 0 if the address
+// does not exist.
+func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ // If a NIC is specified, we try to find the address there only.
+ if nicid != 0 {
+ nic := s.nics[nicid]
+ if nic == nil {
+ return 0
+ }
+
+ ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
+ if ref == nil {
+ return 0
+ }
+
+ ref.decRef()
+
+ return nic.id
+ }
+
+ // Go through all the NICs.
+ for _, nic := range s.nics {
+ ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
+ if ref != nil {
+ ref.decRef()
+ return nic.id
+ }
+ }
+
+ return 0
+}
+
+// SetPromiscuousMode enables or disables promiscuous mode in the given NIC.
+func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setPromiscuousMode(enable)
+
+ return nil
+}
+
+// SetSpoofing enables or disables address spoofing in the given NIC, allowing
+// endpoints to bind to any address in the NIC.
+func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setSpoofing(enable)
+
+ return nil
+}
+
+// AddLinkAddress adds a link address to the stack link cache.
+func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ s.linkAddrCache.add(fullAddr, linkAddr)
+ // TODO: provide a way for a transport endpoint to receive a signal
+ // that AddLinkAddress for a particular address has been called.
+}
+
+// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
+func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+ s.mu.RLock()
+ nic := s.nics[nicid]
+ if nic == nil {
+ s.mu.RUnlock()
+ return "", nil, tcpip.ErrUnknownNICID
+ }
+ s.mu.RUnlock()
+
+ fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ linkRes := s.linkAddrResolvers[protocol]
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
+}
+
+// RemoveWaker implements LinkAddressCache.RemoveWaker.
+func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic := s.nics[nicid]; nic == nil {
+ fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ s.linkAddrCache.removeWaker(fullAddr, waker)
+ }
+}
+
+// RegisterTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided id will be
+// delivered to the given endpoint; specifying a nic is optional, but
+// nic-specific IDs have precedence over global ones.
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ if nicID == 0 {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+}
+
+// UnregisterTransportEndpoint removes the endpoint with the given id from the
+// stack transport dispatcher.
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+ if nicID == 0 {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
+ return
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic != nil {
+ nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
+ }
+}
+
+// RegisterRawTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided transport
+// protocol will be delivered to the given endpoint.
+func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+ if nicID == 0 {
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic == nil {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.demux.registerRawEndpoint(netProto, transProto, ep)
+}
+
+// UnregisterRawTransportEndpoint removes the endpoint for the transport
+// protocol from the stack transport dispatcher.
+func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+ if nicID == 0 {
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
+ return
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic := s.nics[nicID]
+ if nic != nil {
+ nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
+ }
+}
+
+// NetworkProtocolInstance returns the protocol instance in the stack for the
+// specified network protocol. This method is public for protocol implementers
+// and tests to use.
+func (s *Stack) NetworkProtocolInstance(num tcpip.NetworkProtocolNumber) NetworkProtocol {
+ if p, ok := s.networkProtocols[num]; ok {
+ return p
+ }
+ return nil
+}
+
+// TransportProtocolInstance returns the protocol instance in the stack for the
+// specified transport protocol. This method is public for protocol implementers
+// and tests to use.
+func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) TransportProtocol {
+ if pState, ok := s.transportProtocols[num]; ok {
+ return pState.proto
+ }
+ return nil
+}
+
+// AddTCPProbe installs a probe function that will be invoked on every segment
+// received by a given TCP endpoint. The probe function is passed a copy of the
+// TCP endpoint state before and after processing of the segment.
+//
+// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints
+// created prior to this call will not call the probe function.
+//
+// Further, installing two different probes back to back can result in some
+// endpoints calling the first one and some the second one. There is no
+// guarantee provided on which probe will be invoked. Ideally this should only
+// be called once per stack.
+func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
+ s.mu.Lock()
+ s.tcpProbeFunc = probe
+ s.mu.Unlock()
+}
+
+// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
+// otherwise.
+func (s *Stack) GetTCPProbe() TCPProbeFunc {
+ s.mu.Lock()
+ p := s.tcpProbeFunc
+ s.mu.Unlock()
+ return p
+}
+
+// RemoveTCPProbe removes an installed TCP probe.
+//
+// NOTE: This only ensures that endpoints created after this call do not
+// have a probe attached. Endpoints already created will continue to invoke
+// TCP probe.
+func (s *Stack) RemoveTCPProbe() {
+ s.mu.Lock()
+ s.tcpProbeFunc = nil
+ s.mu.Unlock()
+}
+
+// JoinGroup joins the given multicast group on the given NIC.
+func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+ // TODO: notify network of subscription via igmp protocol.
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.joinGroup(protocol, multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
+}
+
+// LeaveGroup leaves the given multicast group on the given NIC.
+func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.leaveGroup(multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
+}
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
new file mode 100644
index 000000000..dfec4258a
--- /dev/null
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -0,0 +1,19 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+// StackFromEnv is the global stack created in restore run.
+// FIXME(b/36201077)
+var StackFromEnv *Stack
diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go
new file mode 100755
index 000000000..bb05ff7c1
--- /dev/null
+++ b/pkg/tcpip/stack/stack_state_autogen.go
@@ -0,0 +1,59 @@
+// automatically generated by stateify.
+
+package stack
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *TransportEndpointID) beforeSave() {}
+func (x *TransportEndpointID) save(m state.Map) {
+ x.beforeSave()
+ m.Save("LocalPort", &x.LocalPort)
+ m.Save("LocalAddress", &x.LocalAddress)
+ m.Save("RemotePort", &x.RemotePort)
+ m.Save("RemoteAddress", &x.RemoteAddress)
+}
+
+func (x *TransportEndpointID) afterLoad() {}
+func (x *TransportEndpointID) load(m state.Map) {
+ m.Load("LocalPort", &x.LocalPort)
+ m.Load("LocalAddress", &x.LocalAddress)
+ m.Load("RemotePort", &x.RemotePort)
+ m.Load("RemoteAddress", &x.RemoteAddress)
+}
+
+func (x *GSOType) save(m state.Map) {
+ m.SaveValue("", (int)(*x))
+}
+
+func (x *GSOType) load(m state.Map) {
+ m.LoadValue("", new(int), func(y interface{}) { *x = (GSOType)(y.(int)) })
+}
+
+func (x *GSO) beforeSave() {}
+func (x *GSO) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Type", &x.Type)
+ m.Save("NeedsCsum", &x.NeedsCsum)
+ m.Save("CsumOffset", &x.CsumOffset)
+ m.Save("MSS", &x.MSS)
+ m.Save("L3HdrLen", &x.L3HdrLen)
+ m.Save("MaxSize", &x.MaxSize)
+}
+
+func (x *GSO) afterLoad() {}
+func (x *GSO) load(m state.Map) {
+ m.Load("Type", &x.Type)
+ m.Load("NeedsCsum", &x.NeedsCsum)
+ m.Load("CsumOffset", &x.CsumOffset)
+ m.Load("MSS", &x.MSS)
+ m.Load("L3HdrLen", &x.L3HdrLen)
+ m.Load("MaxSize", &x.MaxSize)
+}
+
+func init() {
+ state.Register("stack.TransportEndpointID", (*TransportEndpointID)(nil), state.Fns{Save: (*TransportEndpointID).save, Load: (*TransportEndpointID).load})
+ state.Register("stack.GSOType", (*GSOType)(nil), state.Fns{Save: (*GSOType).save, Load: (*GSOType).load})
+ state.Register("stack.GSO", (*GSO)(nil), state.Fns{Save: (*GSO).save, Load: (*GSO).load})
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
new file mode 100644
index 000000000..605bfadeb
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -0,0 +1,420 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "math/rand"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+type protocolIDs struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+}
+
+// transportEndpoints manages all endpoints of a given protocol. It has its own
+// mutex so as to reduce interference between protocols.
+type transportEndpoints struct {
+ // mu protects all fields of the transportEndpoints.
+ mu sync.RWMutex
+ endpoints map[TransportEndpointID]TransportEndpoint
+ // rawEndpoints contains endpoints for raw sockets, which receive all
+ // traffic of a given protocol regardless of port.
+ rawEndpoints []RawTransportEndpoint
+}
+
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ e, ok := eps.endpoints[id]
+ if !ok {
+ return
+ }
+ if multiPortEp, ok := e.(*multiPortEndpoint); ok {
+ if !multiPortEp.unregisterEndpoint(ep) {
+ return
+ }
+ }
+ delete(eps.endpoints, id)
+}
+
+// transportDemuxer demultiplexes packets targeted at a transport endpoint
+// (i.e., after they've been parsed by the network layer). It does two levels
+// of demultiplexing: first based on the network and transport protocols, then
+// based on endpoints IDs. It should only be instantiated via
+// newTransportDemuxer.
+type transportDemuxer struct {
+ // protocol is immutable.
+ protocol map[protocolIDs]*transportEndpoints
+}
+
+func newTransportDemuxer(stack *Stack) *transportDemuxer {
+ d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
+
+ // Add each network and transport pair to the demuxer.
+ for netProto := range stack.networkProtocols {
+ for proto := range stack.transportProtocols {
+ d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
+ endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ }
+ }
+ }
+
+ return d
+}
+
+// registerEndpoint registers the given endpoint with the dispatcher such that
+// packets that match the endpoint ID are delivered to it.
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ for i, n := range netProtos {
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
+ return err
+ }
+ }
+
+ return nil
+}
+
+// multiPortEndpoint is a container for TransportEndpoints which are bound to
+// the same pair of address and port.
+type multiPortEndpoint struct {
+ mu sync.RWMutex
+ endpointsArr []TransportEndpoint
+ endpointsMap map[TransportEndpoint]int
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// reciprocalScale scales a value into range [0, n).
+//
+// This is similar to val % n, but faster.
+// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
+func reciprocalScale(val, n uint32) uint32 {
+ return uint32((uint64(val) * uint64(n)) >> 32)
+}
+
+// selectEndpoint calculates a hash of destination and source addresses and
+// ports then uses it to select a socket. In this case, all packets from one
+// address will be sent to same endpoint.
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ payload := []byte{
+ byte(id.LocalPort),
+ byte(id.LocalPort >> 8),
+ byte(id.RemotePort),
+ byte(id.RemotePort >> 8),
+ }
+
+ h := jenkins.Sum32(ep.seed)
+ h.Write(payload)
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+ hash := h.Sum32()
+
+ idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
+ return ep.endpointsArr[idx]
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints managed by ep.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
+ for i, endpoint := range ep.endpointsArr {
+ // HandlePacket modifies vv, so each endpoint needs its own copy.
+ if i == len(ep.endpointsArr)-1 {
+ endpoint.HandlePacket(r, id, vv)
+ break
+ }
+ vvCopy := buffer.NewView(vv.Size())
+ copy(vvCopy, vv.ToView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+ }
+ } else {
+ ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
+}
+
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ // A new endpoint is added into endpointsArr and its index there is
+ // saved in endpointsMap. This will allows to remove endpoint from
+ // the array fast.
+ ep.endpointsMap[t] = len(ep.endpointsArr)
+ ep.endpointsArr = append(ep.endpointsArr, t)
+}
+
+// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ idx, ok := ep.endpointsMap[t]
+ if !ok {
+ return false
+ }
+ delete(ep.endpointsMap, t)
+ l := len(ep.endpointsArr)
+ if l > 1 {
+ // The last endpoint in endpointsArr is moved instead of the deleted one.
+ lastEp := ep.endpointsArr[l-1]
+ ep.endpointsArr[idx] = lastEp
+ ep.endpointsMap[lastEp] = idx
+ ep.endpointsArr = ep.endpointsArr[0 : l-1]
+ return false
+ }
+ return true
+}
+
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ if id.RemotePort != 0 {
+ reusePort = false
+ }
+
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return nil
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+
+ var multiPortEp *multiPortEndpoint
+ if _, ok := eps.endpoints[id]; ok {
+ if !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
+ if !ok {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ if reusePort {
+ if multiPortEp == nil {
+ multiPortEp = &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.seed = rand.Uint32()
+ eps.endpoints[id] = multiPortEp
+ }
+
+ multiPortEp.singleRegisterEndpoint(ep)
+
+ return nil
+ }
+ eps.endpoints[id] = ep
+
+ return nil
+}
+
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+ for _, n := range netProtos {
+ if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
+ eps.unregisterEndpoint(id, ep)
+ }
+ }
+}
+
+var loopbackSubnet = func() tcpip.Subnet {
+ sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
+ if err != nil {
+ panic(err)
+ }
+ return sn
+}()
+
+// deliverPacket attempts to find one or more matching transport endpoints, and
+// then, if matches are found, delivers the packet to them. Returns true if it
+// found one or more endpoints, false otherwise.
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+ if !ok {
+ return false
+ }
+
+ // If a sender bound to the Loopback interface sends a broadcast,
+ // that broadcast must not be delivered to the sender.
+ if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort {
+ return false
+ }
+
+ // If the packet is a broadcast, then find all matching transport endpoints.
+ // Otherwise, try to find a single matching transport endpoint.
+ destEps := make([]TransportEndpoint, 0, 1)
+ eps.mu.RLock()
+
+ if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
+ for epID, endpoint := range eps.endpoints {
+ if epID.LocalPort == id.LocalPort {
+ destEps = append(destEps, endpoint)
+ }
+ }
+ } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
+ destEps = append(destEps, ep)
+ }
+
+ eps.mu.RUnlock()
+
+ // Fail if we didn't find at least one matching transport endpoint.
+ if len(destEps) == 0 {
+ // UDP packet could not be delivered to an unknown destination port.
+ if protocol == header.UDPProtocolNumber {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ }
+ return false
+ }
+
+ // Deliver the packet.
+ for _, ep := range destEps {
+ ep.HandlePacket(r, id, vv)
+ }
+
+ return true
+}
+
+// deliverRawPacket attempts to deliver the given packet and returns whether it
+// was delivered successfully.
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+ if !ok {
+ return false
+ }
+
+ // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
+ // raw endpoint first. If there are multiple raw endpoints, they all
+ // receive the packet.
+ foundRaw := false
+ eps.mu.RLock()
+ for _, rawEP := range eps.rawEndpoints {
+ // Each endpoint gets its own copy of the packet for the sake
+ // of save/restore.
+ rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ foundRaw = true
+ }
+ eps.mu.RUnlock()
+
+ return foundRaw
+}
+
+// deliverControlPacket attempts to deliver the given control packet. Returns
+// true if it found an endpoint, false otherwise.
+func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{net, trans}]
+ if !ok {
+ return false
+ }
+
+ // Try to find the endpoint.
+ eps.mu.RLock()
+ ep := d.findEndpointLocked(eps, vv, id)
+ eps.mu.RUnlock()
+
+ // Fail if we didn't find one.
+ if ep == nil {
+ return false
+ }
+
+ // Deliver the packet.
+ ep.HandleControlPacket(id, typ, extra, vv)
+
+ return true
+}
+
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+ // Try to find a match with the id as provided.
+ if ep, ok := eps.endpoints[id]; ok {
+ return ep
+ }
+
+ // Try to find a match with the id minus the local address.
+ nid := id
+
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ return ep
+ }
+
+ // Try to find a match with the id minus the remote part.
+ nid.LocalAddress = id.LocalAddress
+ nid.RemoteAddress = ""
+ nid.RemotePort = 0
+ if ep, ok := eps.endpoints[nid]; ok {
+ return ep
+ }
+
+ // Try to find a match with only the local port.
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ return ep
+ }
+
+ return nil
+}
+
+// registerRawEndpoint registers the given endpoint with the dispatcher such
+// that packets of the appropriate protocol are delivered to it. A single
+// packet can be sent to one or more raw endpoints along with a non-raw
+// endpoint.
+func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ return nil
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ eps.rawEndpoints = append(eps.rawEndpoints, ep)
+
+ return nil
+}
+
+// unregisterRawEndpoint unregisters the raw endpoint for the given transport
+// protocol such that it won't receive any more packets.
+func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ for i, rawEP := range eps.rawEndpoints {
+ if rawEP == ep {
+ eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
+ return
+ }
+ }
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
new file mode 100644
index 000000000..f9886c6e4
--- /dev/null
+++ b/pkg/tcpip/tcpip.go
@@ -0,0 +1,1055 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcpip provides the interfaces and related types that users of the
+// tcpip stack will use in order to create endpoints used to send and receive
+// data over the network stack.
+//
+// The starting point is the creation and configuration of a stack. A stack can
+// be created by calling the New() function of the tcpip/stack/stack package;
+// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
+// adding network addresses (via calls to Stack.AddAddress()), and
+// setting a route table (via a call to Stack.SetRouteTable()).
+//
+// Once a stack is configured, endpoints can be created by calling
+// Stack.NewEndpoint(). Such endpoints can be used to send/receive data, connect
+// to peers, listen for connections, accept connections, etc., depending on the
+// transport protocol selected.
+package tcpip
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Error represents an error in the netstack error space. Using a special type
+// ensures that errors outside of this space are not accidentally introduced.
+//
+// Note: to support save / restore, it is important that all tcpip errors have
+// distinct error messages.
+type Error struct {
+ msg string
+
+ ignoreStats bool
+}
+
+// String implements fmt.Stringer.String.
+func (e *Error) String() string {
+ return e.msg
+}
+
+// IgnoreStats indicates whether this error type should be included in failure
+// counts in tcpip.Stats structs.
+func (e *Error) IgnoreStats() bool {
+ return e.ignoreStats
+}
+
+// Errors that can be returned by the network stack.
+var (
+ ErrUnknownProtocol = &Error{msg: "unknown protocol"}
+ ErrUnknownNICID = &Error{msg: "unknown nic id"}
+ ErrUnknownDevice = &Error{msg: "unknown device"}
+ ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
+ ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
+ ErrDuplicateAddress = &Error{msg: "duplicate address"}
+ ErrNoRoute = &Error{msg: "no route"}
+ ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
+ ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
+ ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
+ ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
+ ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
+ ErrNoPortAvailable = &Error{msg: "no ports are available"}
+ ErrPortInUse = &Error{msg: "port is in use"}
+ ErrBadLocalAddress = &Error{msg: "bad local address"}
+ ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
+ ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
+ ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
+ ErrConnectionRefused = &Error{msg: "connection was refused"}
+ ErrTimeout = &Error{msg: "operation timed out"}
+ ErrAborted = &Error{msg: "operation aborted"}
+ ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
+ ErrDestinationRequired = &Error{msg: "destination address is required"}
+ ErrNotSupported = &Error{msg: "operation not supported"}
+ ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
+ ErrNotConnected = &Error{msg: "endpoint not connected"}
+ ErrConnectionReset = &Error{msg: "connection reset by peer"}
+ ErrConnectionAborted = &Error{msg: "connection aborted"}
+ ErrNoSuchFile = &Error{msg: "no such file"}
+ ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
+ ErrNoLinkAddress = &Error{msg: "no remote link address"}
+ ErrBadAddress = &Error{msg: "bad address"}
+ ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
+ ErrMessageTooLong = &Error{msg: "message too long"}
+ ErrNoBufferSpace = &Error{msg: "no buffer space available"}
+ ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
+ ErrNotPermitted = &Error{msg: "operation not permitted"}
+)
+
+// Errors related to Subnet
+var (
+ errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
+ errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask")
+)
+
+// ErrSaveRejection indicates a failed save due to unsupported networking state.
+// This type of errors is only used for save logic.
+type ErrSaveRejection struct {
+ Err error
+}
+
+// Error returns a sensible description of the save rejection error.
+func (e ErrSaveRejection) Error() string {
+ return "save rejected due to unsupported networking state: " + e.Err.Error()
+}
+
+// A Clock provides the current time.
+//
+// Times returned by a Clock should always be used for application-visible
+// time. Only monotonic times should be used for netstack internal timekeeping.
+type Clock interface {
+ // NowNanoseconds returns the current real time as a number of
+ // nanoseconds since the Unix epoch.
+ NowNanoseconds() int64
+
+ // NowMonotonic returns a monotonic time value.
+ NowMonotonic() int64
+}
+
+// Address is a byte slice cast as a string that represents the address of a
+// network node. Or, in the case of unix endpoints, it may represent a path.
+type Address string
+
+// AddressMask is a bitmask for an address.
+type AddressMask string
+
+// String implements Stringer.
+func (a AddressMask) String() string {
+ return Address(a).String()
+}
+
+// Subnet is a subnet defined by its address and mask.
+type Subnet struct {
+ address Address
+ mask AddressMask
+}
+
+// NewSubnet creates a new Subnet, checking that the address and mask are the same length.
+func NewSubnet(a Address, m AddressMask) (Subnet, error) {
+ if len(a) != len(m) {
+ return Subnet{}, errSubnetLengthMismatch
+ }
+ for i := 0; i < len(a); i++ {
+ if a[i]&^m[i] != 0 {
+ return Subnet{}, errSubnetAddressMasked
+ }
+ }
+ return Subnet{a, m}, nil
+}
+
+// Contains returns true iff the address is of the same length and matches the
+// subnet address and mask.
+func (s *Subnet) Contains(a Address) bool {
+ if len(a) != len(s.address) {
+ return false
+ }
+ for i := 0; i < len(a); i++ {
+ if a[i]&s.mask[i] != s.address[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// ID returns the subnet ID.
+func (s *Subnet) ID() Address {
+ return s.address
+}
+
+// Bits returns the number of ones (network bits) and zeros (host bits) in the
+// subnet mask.
+func (s *Subnet) Bits() (ones int, zeros int) {
+ for _, b := range []byte(s.mask) {
+ for i := uint(0); i < 8; i++ {
+ if b&(1<<i) == 0 {
+ zeros++
+ } else {
+ ones++
+ }
+ }
+ }
+ return
+}
+
+// Prefix returns the number of bits before the first host bit.
+func (s *Subnet) Prefix() int {
+ for i, b := range []byte(s.mask) {
+ for j := 7; j >= 0; j-- {
+ if b&(1<<uint(j)) == 0 {
+ return i*8 + 7 - j
+ }
+ }
+ }
+ return len(s.mask) * 8
+}
+
+// Mask returns the subnet mask.
+func (s *Subnet) Mask() AddressMask {
+ return s.mask
+}
+
+// NICID is a number that uniquely identifies a NIC.
+type NICID int32
+
+// ShutdownFlags represents flags that can be passed to the Shutdown() method
+// of the Endpoint interface.
+type ShutdownFlags int
+
+// Values of the flags that can be passed to the Shutdown() method. They can
+// be OR'ed together.
+const (
+ ShutdownRead ShutdownFlags = 1 << iota
+ ShutdownWrite
+)
+
+// FullAddress represents a full transport node address, as required by the
+// Connect() and Bind() methods.
+//
+// +stateify savable
+type FullAddress struct {
+ // NIC is the ID of the NIC this address refers to.
+ //
+ // This may not be used by all endpoint types.
+ NIC NICID
+
+ // Addr is the network address.
+ Addr Address
+
+ // Port is the transport port.
+ //
+ // This may not be used by all endpoint types.
+ Port uint16
+}
+
+// Payload provides an interface around data that is being sent to an endpoint.
+// This allows the endpoint to request the amount of data it needs based on
+// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data.
+type Payload interface {
+ // Get returns a slice containing exactly 'min(size, p.Size())' bytes.
+ Get(size int) ([]byte, *Error)
+
+ // Size returns the payload size.
+ Size() int
+}
+
+// SlicePayload implements Payload on top of slices for convenience.
+type SlicePayload []byte
+
+// Get implements Payload.
+func (s SlicePayload) Get(size int) ([]byte, *Error) {
+ if size > s.Size() {
+ size = s.Size()
+ }
+ return s[:size], nil
+}
+
+// Size implements Payload.
+func (s SlicePayload) Size() int {
+ return len(s)
+}
+
+// A ControlMessages contains socket control messages for IP sockets.
+//
+// +stateify savable
+type ControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packed used to create
+ // the read data was received.
+ Timestamp int64
+}
+
+// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
+// that exposes functionality like read, write, connect, etc. to users of the
+// networking stack.
+type Endpoint interface {
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it.
+ Close()
+
+ // Read reads data from the endpoint and optionally returns the sender.
+ //
+ // This method does not block if there is no data pending. It will also
+ // either return an error or data, never both.
+ Read(*FullAddress) (buffer.View, ControlMessages, *Error)
+
+ // Write writes data to the endpoint's peer. This method does not block if
+ // the data cannot be written.
+ //
+ // Unlike io.Writer.Write, Endpoint.Write transfers ownership of any bytes
+ // successfully written to the Endpoint. That is, if a call to
+ // Write(SlicePayload{data}) returns (n, err), it may retain data[:n], and
+ // the caller should not use data[:n] after Write returns.
+ //
+ // Note that unlike io.Writer.Write, it is not an error for Write to
+ // perform a partial write (if n > 0, no error may be returned). Only
+ // stream (TCP) Endpoints may return partial writes, and even then only
+ // in the case where writing additional data would block. Other Endpoints
+ // will either write the entire message or return an error.
+ //
+ // For UDP and Ping sockets if address resolution is required,
+ // ErrNoLinkAddress and a notification channel is returned for the caller to
+ // block. Channel is closed once address resolution is complete (success or
+ // not). The channel is only non-nil in this case.
+ Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error)
+
+ // Peek reads data without consuming it from the endpoint.
+ //
+ // This method does not block if there is no data pending.
+ Peek([][]byte) (uintptr, ControlMessages, *Error)
+
+ // Connect connects the endpoint to its peer. Specifying a NIC is
+ // optional.
+ //
+ // There are three classes of return values:
+ // nil -- the attempt to connect succeeded.
+ // ErrConnectStarted/ErrAlreadyConnecting -- the connect attempt started
+ // but hasn't completed yet. In this case, the caller must call Connect
+ // or GetSockOpt(ErrorOption) when the endpoint becomes writable to
+ // get the actual result. The first call to Connect after the socket has
+ // connected returns nil. Calling connect again results in ErrAlreadyConnected.
+ // Anything else -- the attempt to connect failed.
+ Connect(address FullAddress) *Error
+
+ // Shutdown closes the read and/or write end of the endpoint connection
+ // to its peer.
+ Shutdown(flags ShutdownFlags) *Error
+
+ // Listen puts the endpoint in "listen" mode, which allows it to accept
+ // new connections.
+ Listen(backlog int) *Error
+
+ // Accept returns a new endpoint if a peer has established a connection
+ // to an endpoint previously set to listen mode. This method does not
+ // block if no new connections are available.
+ //
+ // The returned Queue is the wait queue for the newly created endpoint.
+ Accept() (Endpoint, *waiter.Queue, *Error)
+
+ // Bind binds the endpoint to a specific local address and port.
+ // Specifying a NIC is optional.
+ Bind(address FullAddress) *Error
+
+ // GetLocalAddress returns the address to which the endpoint is bound.
+ GetLocalAddress() (FullAddress, *Error)
+
+ // GetRemoteAddress returns the address to which the endpoint is
+ // connected.
+ GetRemoteAddress() (FullAddress, *Error)
+
+ // Readiness returns the current readiness of the endpoint. For example,
+ // if waiter.EventIn is set, the endpoint is immediately readable.
+ Readiness(mask waiter.EventMask) waiter.EventMask
+
+ // SetSockOpt sets a socket option. opt should be one of the *Option types.
+ SetSockOpt(opt interface{}) *Error
+
+ // GetSockOpt gets a socket option. opt should be a pointer to one of the
+ // *Option types.
+ GetSockOpt(opt interface{}) *Error
+}
+
+// WriteOptions contains options for Endpoint.Write.
+type WriteOptions struct {
+ // If To is not nil, write to the given address instead of the endpoint's
+ // peer.
+ To *FullAddress
+
+ // More has the same semantics as Linux's MSG_MORE.
+ More bool
+
+ // EndOfRecord has the same semantics as Linux's MSG_EOR.
+ EndOfRecord bool
+}
+
+// ErrorOption is used in GetSockOpt to specify that the last error reported by
+// the endpoint should be cleared and returned.
+type ErrorOption struct{}
+
+// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
+// buffer size option.
+type SendBufferSizeOption int
+
+// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
+// receive buffer size option.
+type ReceiveBufferSizeOption int
+
+// SendQueueSizeOption is used in GetSockOpt to specify that the number of
+// unread bytes in the output buffer should be returned.
+type SendQueueSizeOption int
+
+// ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
+// unread bytes in the input buffer should be returned.
+type ReceiveQueueSizeOption int
+
+// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
+// socket is to be restricted to sending and receiving IPv6 packets only.
+type V6OnlyOption int
+
+// DelayOption is used by SetSockOpt/GetSockOpt to specify if data should be
+// sent out immediately by the transport protocol. For TCP, it determines if the
+// Nagle algorithm is on or off.
+type DelayOption int
+
+// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
+// held until segments are full by the TCP transport protocol.
+type CorkOption int
+
+// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
+// should allow reuse of local address.
+type ReuseAddressOption int
+
+// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
+// to be bound to an identical socket address.
+type ReusePortOption int
+
+// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
+type QuickAckOption int
+
+// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
+// SCM_CREDENTIALS socket control messages are enabled.
+//
+// Only supported on Unix sockets.
+type PasscredOption int
+
+// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
+//
+// TODO(b/64800844): Add and populate stat fields.
+type TCPInfoOption struct {
+ RTT time.Duration
+ RTTVar time.Duration
+}
+
+// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
+// TCP keepalive is enabled for this socket.
+type KeepaliveEnabledOption int
+
+// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a
+// connection must remain idle before the first TCP keepalive packet is sent.
+// Once this time is reached, KeepaliveIntervalOption is used instead.
+type KeepaliveIdleOption time.Duration
+
+// KeepaliveIntervalOption is used by SetSockOpt/GetSockOpt to specify the
+// interval between sending TCP keepalive packets.
+type KeepaliveIntervalOption time.Duration
+
+// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
+// of un-ACKed TCP keepalives that will be sent before the connection is
+// closed.
+type KeepaliveCountOption int
+
+// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
+// TTL value for multicast messages. The default is 1.
+type MulticastTTLOption uint8
+
+// MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a
+// default interface for multicast.
+type MulticastInterfaceOption struct {
+ NIC NICID
+ InterfaceAddr Address
+}
+
+// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether
+// multicast packets sent over a non-loopback interface will be looped back.
+type MulticastLoopOption bool
+
+// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
+// AddMembershipOption and RemoveMembershipOption.
+type MembershipOption struct {
+ NIC NICID
+ InterfaceAddr Address
+ MulticastAddr Address
+}
+
+// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast
+// group identified by the given multicast address, on the interface matching
+// the given interface address.
+type AddMembershipOption MembershipOption
+
+// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast
+// group identified by the given multicast address, on the interface matching
+// the given interface address.
+type RemoveMembershipOption MembershipOption
+
+// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether
+// TCP out-of-band data is delivered along with the normal in-band data.
+type OutOfBandInlineOption int
+
+// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether
+// datagram sockets are allowed to send packets to a broadcast address.
+type BroadcastOption int
+
+// Route is a row in the routing table. It specifies through which NIC (and
+// gateway) sets of packets should be routed. A row is considered viable if the
+// masked target address matches the destination adddress in the row.
+type Route struct {
+ // Destination is the address that must be matched against the masked
+ // target address to check if this row is viable.
+ Destination Address
+
+ // Mask specifies which bits of the Destination and the target address
+ // must match for this row to be viable.
+ Mask AddressMask
+
+ // Gateway is the gateway to be used if this row is viable.
+ Gateway Address
+
+ // NIC is the id of the nic to be used if this row is viable.
+ NIC NICID
+}
+
+// Match determines if r is viable for the given destination address.
+func (r *Route) Match(addr Address) bool {
+ if len(addr) != len(r.Destination) {
+ return false
+ }
+
+ // Using header.Ipv4Broadcast would introduce an import cycle, so
+ // we'll use a literal instead.
+ if addr == "\xff\xff\xff\xff" {
+ return true
+ }
+
+ for i := 0; i < len(r.Destination); i++ {
+ if (addr[i] & r.Mask[i]) != r.Destination[i] {
+ return false
+ }
+ }
+
+ return true
+}
+
+// LinkEndpointID represents a data link layer endpoint.
+type LinkEndpointID uint64
+
+// TransportProtocolNumber is the number of a transport protocol.
+type TransportProtocolNumber uint32
+
+// NetworkProtocolNumber is the number of a network protocol.
+type NetworkProtocolNumber uint32
+
+// A StatCounter keeps track of a statistic.
+type StatCounter struct {
+ count uint64
+}
+
+// Increment adds one to the counter.
+func (s *StatCounter) Increment() {
+ s.IncrementBy(1)
+}
+
+// Value returns the current value of the counter.
+func (s *StatCounter) Value() uint64 {
+ return atomic.LoadUint64(&s.count)
+}
+
+// IncrementBy increments the counter by v.
+func (s *StatCounter) IncrementBy(v uint64) {
+ atomic.AddUint64(&s.count, v)
+}
+
+func (s *StatCounter) String() string {
+ return strconv.FormatUint(s.Value(), 10)
+}
+
+// ICMPv4PacketStats enumerates counts for all ICMPv4 packet types.
+type ICMPv4PacketStats struct {
+ // Echo is the total number of ICMPv4 echo packets counted.
+ Echo *StatCounter
+
+ // EchoReply is the total number of ICMPv4 echo reply packets counted.
+ EchoReply *StatCounter
+
+ // DstUnreachable is the total number of ICMPv4 destination unreachable
+ // packets counted.
+ DstUnreachable *StatCounter
+
+ // SrcQuench is the total number of ICMPv4 source quench packets
+ // counted.
+ SrcQuench *StatCounter
+
+ // Redirect is the total number of ICMPv4 redirect packets counted.
+ Redirect *StatCounter
+
+ // TimeExceeded is the total number of ICMPv4 time exceeded packets
+ // counted.
+ TimeExceeded *StatCounter
+
+ // ParamProblem is the total number of ICMPv4 parameter problem packets
+ // counted.
+ ParamProblem *StatCounter
+
+ // Timestamp is the total number of ICMPv4 timestamp packets counted.
+ Timestamp *StatCounter
+
+ // TimestampReply is the total number of ICMPv4 timestamp reply packets
+ // counted.
+ TimestampReply *StatCounter
+
+ // InfoRequest is the total number of ICMPv4 information request
+ // packets counted.
+ InfoRequest *StatCounter
+
+ // InfoReply is the total number of ICMPv4 information reply packets
+ // counted.
+ InfoReply *StatCounter
+}
+
+// ICMPv6PacketStats enumerates counts for all ICMPv6 packet types.
+type ICMPv6PacketStats struct {
+ // EchoRequest is the total number of ICMPv6 echo request packets
+ // counted.
+ EchoRequest *StatCounter
+
+ // EchoReply is the total number of ICMPv6 echo reply packets counted.
+ EchoReply *StatCounter
+
+ // DstUnreachable is the total number of ICMPv6 destination unreachable
+ // packets counted.
+ DstUnreachable *StatCounter
+
+ // PacketTooBig is the total number of ICMPv6 packet too big packets
+ // counted.
+ PacketTooBig *StatCounter
+
+ // TimeExceeded is the total number of ICMPv6 time exceeded packets
+ // counted.
+ TimeExceeded *StatCounter
+
+ // ParamProblem is the total number of ICMPv6 parameter problem packets
+ // counted.
+ ParamProblem *StatCounter
+
+ // RouterSolicit is the total number of ICMPv6 router solicit packets
+ // counted.
+ RouterSolicit *StatCounter
+
+ // RouterAdvert is the total number of ICMPv6 router advert packets
+ // counted.
+ RouterAdvert *StatCounter
+
+ // NeighborSolicit is the total number of ICMPv6 neighbor solicit
+ // packets counted.
+ NeighborSolicit *StatCounter
+
+ // NeighborAdvert is the total number of ICMPv6 neighbor advert packets
+ // counted.
+ NeighborAdvert *StatCounter
+
+ // RedirectMsg is the total number of ICMPv6 redirect message packets
+ // counted.
+ RedirectMsg *StatCounter
+}
+
+// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
+type ICMPv4SentPacketStats struct {
+ ICMPv4PacketStats
+
+ // Dropped is the total number of ICMPv4 packets dropped due to link
+ // layer errors.
+ Dropped *StatCounter
+}
+
+// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats.
+type ICMPv4ReceivedPacketStats struct {
+ ICMPv4PacketStats
+
+ // Invalid is the total number of ICMPv4 packets received that the
+ // transport layer could not parse.
+ Invalid *StatCounter
+}
+
+// ICMPv6SentPacketStats collects outbound ICMPv6-specific stats.
+type ICMPv6SentPacketStats struct {
+ ICMPv6PacketStats
+
+ // Dropped is the total number of ICMPv6 packets dropped due to link
+ // layer errors.
+ Dropped *StatCounter
+}
+
+// ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats.
+type ICMPv6ReceivedPacketStats struct {
+ ICMPv6PacketStats
+
+ // Invalid is the total number of ICMPv6 packets received that the
+ // transport layer could not parse.
+ Invalid *StatCounter
+}
+
+// ICMPStats collects ICMP-specific stats (both v4 and v6).
+type ICMPStats struct {
+ // ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type
+ // and a single count of packets which failed to write to the link
+ // layer.
+ V4PacketsSent ICMPv4SentPacketStats
+
+ // ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4
+ // packet type and a single count of invalid packets received.
+ V4PacketsReceived ICMPv4ReceivedPacketStats
+
+ // ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type
+ // and a single count of packets which failed to write to the link
+ // layer.
+ V6PacketsSent ICMPv6SentPacketStats
+
+ // ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6
+ // packet type and a single count of invalid packets received.
+ V6PacketsReceived ICMPv6ReceivedPacketStats
+}
+
+// IPStats collects IP-specific stats (both v4 and v6).
+type IPStats struct {
+ // PacketsReceived is the total number of IP packets received from the
+ // link layer in nic.DeliverNetworkPacket.
+ PacketsReceived *StatCounter
+
+ // InvalidAddressesReceived is the total number of IP packets received
+ // with an unknown or invalid destination address.
+ InvalidAddressesReceived *StatCounter
+
+ // PacketsDelivered is the total number of incoming IP packets that
+ // are successfully delivered to the transport layer via HandlePacket.
+ PacketsDelivered *StatCounter
+
+ // PacketsSent is the total number of IP packets sent via WritePacket.
+ PacketsSent *StatCounter
+
+ // OutgoingPacketErrors is the total number of IP packets which failed
+ // to write to a link-layer endpoint.
+ OutgoingPacketErrors *StatCounter
+}
+
+// TCPStats collects TCP-specific stats.
+type TCPStats struct {
+ // ActiveConnectionOpenings is the number of connections opened
+ // successfully via Connect.
+ ActiveConnectionOpenings *StatCounter
+
+ // PassiveConnectionOpenings is the number of connections opened
+ // successfully via Listen.
+ PassiveConnectionOpenings *StatCounter
+
+ // ListenOverflowSynDrop is the number of times the listen queue overflowed
+ // and a SYN was dropped.
+ ListenOverflowSynDrop *StatCounter
+
+ // ListenOverflowAckDrop is the number of times the final ACK
+ // in the handshake was dropped due to overflow.
+ ListenOverflowAckDrop *StatCounter
+
+ // ListenOverflowCookieSent is the number of times a SYN cookie was sent.
+ ListenOverflowSynCookieSent *StatCounter
+
+ // ListenOverflowSynCookieRcvd is the number of times a valid SYN
+ // cookie was received.
+ ListenOverflowSynCookieRcvd *StatCounter
+
+ // ListenOverflowInvalidSynCookieRcvd is the number of times an invalid SYN cookie
+ // was received.
+ ListenOverflowInvalidSynCookieRcvd *StatCounter
+
+ // FailedConnectionAttempts is the number of calls to Connect or Listen
+ // (active and passive openings, respectively) that end in an error.
+ FailedConnectionAttempts *StatCounter
+
+ // ValidSegmentsReceived is the number of TCP segments received that
+ // the transport layer successfully parsed.
+ ValidSegmentsReceived *StatCounter
+
+ // InvalidSegmentsReceived is the number of TCP segments received that
+ // the transport layer could not parse.
+ InvalidSegmentsReceived *StatCounter
+
+ // SegmentsSent is the number of TCP segments sent.
+ SegmentsSent *StatCounter
+
+ // ResetsSent is the number of TCP resets sent.
+ ResetsSent *StatCounter
+
+ // ResetsReceived is the number of TCP resets received.
+ ResetsReceived *StatCounter
+
+ // Retransmits is the number of TCP segments retransmitted.
+ Retransmits *StatCounter
+
+ // FastRecovery is the number of times Fast Recovery was used to
+ // recover from packet loss.
+ FastRecovery *StatCounter
+
+ // SACKRecovery is the number of times SACK Recovery was used to
+ // recover from packet loss.
+ SACKRecovery *StatCounter
+
+ // SlowStartRetransmits is the number of segments retransmitted in slow
+ // start.
+ SlowStartRetransmits *StatCounter
+
+ // FastRetransmit is the number of segments retransmitted in fast
+ // recovery.
+ FastRetransmit *StatCounter
+
+ // Timeouts is the number of times the RTO expired.
+ Timeouts *StatCounter
+
+ // ChecksumErrors is the number of segments dropped due to bad checksums.
+ ChecksumErrors *StatCounter
+}
+
+// UDPStats collects UDP-specific stats.
+type UDPStats struct {
+ // PacketsReceived is the number of UDP datagrams received via
+ // HandlePacket.
+ PacketsReceived *StatCounter
+
+ // UnknownPortErrors is the number of incoming UDP datagrams dropped
+ // because they did not have a known destination port.
+ UnknownPortErrors *StatCounter
+
+ // ReceiveBufferErrors is the number of incoming UDP datagrams dropped
+ // due to the receiving buffer being in an invalid state.
+ ReceiveBufferErrors *StatCounter
+
+ // MalformedPacketsReceived is the number of incoming UDP datagrams
+ // dropped due to the UDP header being in a malformed state.
+ MalformedPacketsReceived *StatCounter
+
+ // PacketsSent is the number of UDP datagrams sent via sendUDP.
+ PacketsSent *StatCounter
+}
+
+// Stats holds statistics about the networking stack.
+//
+// All fields are optional.
+type Stats struct {
+ // UnknownProtocolRcvdPackets is the number of packets received by the
+ // stack that were for an unknown or unsupported protocol.
+ UnknownProtocolRcvdPackets *StatCounter
+
+ // MalformedRcvPackets is the number of packets received by the stack
+ // that were deemed malformed.
+ MalformedRcvdPackets *StatCounter
+
+ // DroppedPackets is the number of packets dropped due to full queues.
+ DroppedPackets *StatCounter
+
+ // ICMP breaks out ICMP-specific stats (both v4 and v6).
+ ICMP ICMPStats
+
+ // IP breaks out IP-specific stats (both v4 and v6).
+ IP IPStats
+
+ // TCP breaks out TCP-specific stats.
+ TCP TCPStats
+
+ // UDP breaks out UDP-specific stats.
+ UDP UDPStats
+}
+
+func fillIn(v reflect.Value) {
+ for i := 0; i < v.NumField(); i++ {
+ v := v.Field(i)
+ switch v.Kind() {
+ case reflect.Ptr:
+ if s := v.Addr().Interface().(**StatCounter); *s == nil {
+ *s = &StatCounter{}
+ }
+ case reflect.Struct:
+ fillIn(v)
+ default:
+ panic(fmt.Sprintf("unexpected type %s", v.Type()))
+ }
+ }
+}
+
+// FillIn returns a copy of s with nil fields initialized to new StatCounters.
+func (s Stats) FillIn() Stats {
+ fillIn(reflect.ValueOf(&s).Elem())
+ return s
+}
+
+// String implements the fmt.Stringer interface.
+func (a Address) String() string {
+ switch len(a) {
+ case 4:
+ return fmt.Sprintf("%d.%d.%d.%d", int(a[0]), int(a[1]), int(a[2]), int(a[3]))
+ case 16:
+ // Find the longest subsequence of hexadecimal zeros.
+ start, end := -1, -1
+ for i := 0; i < len(a); i += 2 {
+ j := i
+ for j < len(a) && a[j] == 0 && a[j+1] == 0 {
+ j += 2
+ }
+ if j > i+2 && j-i > end-start {
+ start, end = i, j
+ }
+ }
+
+ var b strings.Builder
+ for i := 0; i < len(a); i += 2 {
+ if i == start {
+ b.WriteString("::")
+ i = end
+ if end >= len(a) {
+ break
+ }
+ } else if i > 0 {
+ b.WriteByte(':')
+ }
+ v := uint16(a[i+0])<<8 | uint16(a[i+1])
+ if v == 0 {
+ b.WriteByte('0')
+ } else {
+ const digits = "0123456789abcdef"
+ for i := uint(3); i < 4; i-- {
+ if v := v >> (i * 4); v != 0 {
+ b.WriteByte(digits[v&0xf])
+ }
+ }
+ }
+ }
+ return b.String()
+ default:
+ return fmt.Sprintf("%x", []byte(a))
+ }
+}
+
+// To4 converts the IPv4 address to a 4-byte representation.
+// If the address is not an IPv4 address, To4 returns "".
+func (a Address) To4() Address {
+ const (
+ ipv4len = 4
+ ipv6len = 16
+ )
+ if len(a) == ipv4len {
+ return a
+ }
+ if len(a) == ipv6len &&
+ isZeros(a[0:10]) &&
+ a[10] == 0xff &&
+ a[11] == 0xff {
+ return a[12:16]
+ }
+ return ""
+}
+
+// isZeros reports whether a is all zeros.
+func isZeros(a Address) bool {
+ for i := 0; i < len(a); i++ {
+ if a[i] != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// LinkAddress is a byte slice cast as a string that represents a link address.
+// It is typically a 6-byte MAC address.
+type LinkAddress string
+
+// String implements the fmt.Stringer interface.
+func (a LinkAddress) String() string {
+ switch len(a) {
+ case 6:
+ return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", a[0], a[1], a[2], a[3], a[4], a[5])
+ default:
+ return fmt.Sprintf("%x", []byte(a))
+ }
+}
+
+// ParseMACAddress parses an IEEE 802 address.
+//
+// It must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
+func ParseMACAddress(s string) (LinkAddress, error) {
+ parts := strings.FieldsFunc(s, func(c rune) bool {
+ return c == ':' || c == '-'
+ })
+ if len(parts) != 6 {
+ return "", fmt.Errorf("inconsistent parts: %s", s)
+ }
+ addr := make([]byte, 0, len(parts))
+ for _, part := range parts {
+ u, err := strconv.ParseUint(part, 16, 8)
+ if err != nil {
+ return "", fmt.Errorf("invalid hex digits: %s", s)
+ }
+ addr = append(addr, byte(u))
+ }
+ return LinkAddress(addr), nil
+}
+
+// ProtocolAddress is an address and the network protocol it is associated
+// with.
+type ProtocolAddress struct {
+ // Protocol is the protocol of the address.
+ Protocol NetworkProtocolNumber
+
+ // Address is a network address.
+ Address Address
+}
+
+// danglingEndpointsMu protects access to danglingEndpoints.
+var danglingEndpointsMu sync.Mutex
+
+// danglingEndpoints tracks all dangling endpoints no longer owned by the app.
+var danglingEndpoints = make(map[Endpoint]struct{})
+
+// GetDanglingEndpoints returns all dangling endpoints.
+func GetDanglingEndpoints() []Endpoint {
+ es := make([]Endpoint, 0, len(danglingEndpoints))
+ danglingEndpointsMu.Lock()
+ for e := range danglingEndpoints {
+ es = append(es, e)
+ }
+ danglingEndpointsMu.Unlock()
+ return es
+}
+
+// AddDanglingEndpoint adds a dangling endpoint.
+func AddDanglingEndpoint(e Endpoint) {
+ danglingEndpointsMu.Lock()
+ danglingEndpoints[e] = struct{}{}
+ danglingEndpointsMu.Unlock()
+}
+
+// DeleteDanglingEndpoint removes a dangling endpoint.
+func DeleteDanglingEndpoint(e Endpoint) {
+ danglingEndpointsMu.Lock()
+ delete(danglingEndpoints, e)
+ danglingEndpointsMu.Unlock()
+}
+
+// AsyncLoading is the global barrier for asynchronous endpoint loading
+// activities.
+var AsyncLoading sync.WaitGroup
diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go
new file mode 100755
index 000000000..3ed2e29f4
--- /dev/null
+++ b/pkg/tcpip/tcpip_state_autogen.go
@@ -0,0 +1,40 @@
+// automatically generated by stateify.
+
+package tcpip
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *FullAddress) beforeSave() {}
+func (x *FullAddress) save(m state.Map) {
+ x.beforeSave()
+ m.Save("NIC", &x.NIC)
+ m.Save("Addr", &x.Addr)
+ m.Save("Port", &x.Port)
+}
+
+func (x *FullAddress) afterLoad() {}
+func (x *FullAddress) load(m state.Map) {
+ m.Load("NIC", &x.NIC)
+ m.Load("Addr", &x.Addr)
+ m.Load("Port", &x.Port)
+}
+
+func (x *ControlMessages) beforeSave() {}
+func (x *ControlMessages) save(m state.Map) {
+ x.beforeSave()
+ m.Save("HasTimestamp", &x.HasTimestamp)
+ m.Save("Timestamp", &x.Timestamp)
+}
+
+func (x *ControlMessages) afterLoad() {}
+func (x *ControlMessages) load(m state.Map) {
+ m.Load("HasTimestamp", &x.HasTimestamp)
+ m.Load("Timestamp", &x.Timestamp)
+}
+
+func init() {
+ state.Register("tcpip.FullAddress", (*FullAddress)(nil), state.Fns{Save: (*FullAddress).save, Load: (*FullAddress).load})
+ state.Register("tcpip.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load})
+}
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
new file mode 100644
index 000000000..a52262e87
--- /dev/null
+++ b/pkg/tcpip/time_unsafe.go
@@ -0,0 +1,45 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.9
+// +build !go1.14
+
+// Check go:linkname function signatures when updating Go version.
+
+package tcpip
+
+import (
+ _ "time" // Used with go:linkname.
+ _ "unsafe" // Required for go:linkname.
+)
+
+// StdClock implements Clock with the time package.
+type StdClock struct{}
+
+var _ Clock = (*StdClock)(nil)
+
+//go:linkname now time.now
+func now() (sec int64, nsec int32, mono int64)
+
+// NowNanoseconds implements Clock.NowNanoseconds.
+func (*StdClock) NowNanoseconds() int64 {
+ sec, nsec, _ := now()
+ return sec*1e9 + int64(nsec)
+}
+
+// NowMonotonic implements Clock.NowMonotonic.
+func (*StdClock) NowMonotonic() int64 {
+ _, _, mono := now()
+ return mono
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
new file mode 100644
index 000000000..e2b90ef10
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -0,0 +1,710 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp
+
+import (
+ "encoding/binary"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type icmpPacket struct {
+ icmpPacketEntry
+ senderAddress tcpip.FullAddress
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
+ // views is used as buffer for data when its length is large
+ // enough to store a VectorisedView.
+ views [8]buffer.View `state:"nosave"`
+}
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateConnected
+ stateClosed
+)
+
+// endpoint represents an ICMP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized.
+//
+// +stateify savable
+type endpoint struct {
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue, and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList icmpPacketList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+ id stack.TransportEndpointID
+ state endpointState
+ // bindNICID and bindAddr are set via calls to Bind(). They are used to
+ // reject attempts to send data or connect via a different NIC or
+ // address
+ bindNICID tcpip.NICID
+ bindAddr tcpip.Address
+ // regNICID is the default NIC to be used when callers don't specify a
+ // NIC.
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return &endpoint{
+ stack: stack,
+ netProto: netProto,
+ transProto: transProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }, nil
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
+ switch e.state {
+ case stateBound, stateConnected:
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ }
+
+ // Close the receive list and drain it.
+ e.rcvMu.Lock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
+ e.rcvMu.Unlock()
+
+ e.route.Release()
+
+ // Update the state.
+ e.state = stateClosed
+
+ e.mu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.rcvMu.Lock()
+
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = p.senderAddress
+ }
+
+ return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+ switch e.state {
+ case stateInitial:
+ case stateConnected:
+ return false, nil
+
+ case stateBound:
+ if to == nil {
+ return false, tcpip.ErrDestinationRequired
+ }
+ return false, nil
+ default:
+ return false, tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // The state changed when we released the shared locked and re-acquired
+ // it in exclusive mode. Try again.
+ if e.state != stateInitial {
+ return true, nil
+ }
+
+ // The state is still 'initial', so try to bind the endpoint.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ to := opts.To
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // If we've shutdown with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Prepare for write.
+ for {
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !retry {
+ break
+ }
+ }
+
+ var route *stack.Route
+ if to == nil {
+ route = &e.route
+
+ if route.IsResolutionRequired() {
+ // Promote lock to exclusive if using a shared route,
+ // given that it may need to change in Route.Resolve()
+ // call below.
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != stateConnected {
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+ }
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicid := to.NIC
+ if e.bindNICID != 0 {
+ if nicid != 0 && nicid != e.bindNICID {
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ nicid = e.bindNICID
+ }
+
+ toCopy := *to
+ to = &toCopy
+ netProto, err := e.checkV4Mapped(to, true)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Find the enpoint.
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer r.Release()
+
+ route = &r
+ }
+
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ v, err := p.Get(p.Size())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch e.netProto {
+ case header.IPv4ProtocolNumber:
+ err = e.send4(route, v)
+
+ case header.IPv6ProtocolNumber:
+ err = send6(route, e.id.LocalPort, v)
+ }
+
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(len(v)), nil, nil
+}
+
+// Peek only returns data from a single datagram, so do nothing here.
+func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ e.rcvMu.Lock()
+ if e.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := e.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error {
+ if len(data) < header.ICMPv4EchoMinimumSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort)
+
+ hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
+ copy(icmpv4, data)
+ data = data[header.ICMPv4EchoMinimumSize:]
+
+ // Linux performs these basic checks.
+ if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ icmpv4.SetChecksum(0)
+ icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+
+ return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
+}
+
+func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+ if len(data) < header.ICMPv6EchoMinimumSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Set the ident. Sequence number is provided by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident)
+
+ hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(icmpv6, data)
+ data = data[header.ICMPv6EchoMinimumSize:]
+
+ if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ icmpv6.SetChecksum(0)
+ icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
+
+ return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL())
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicid := addr.NIC
+ localPort := uint16(0)
+ switch e.state {
+ case stateBound, stateConnected:
+ localPort = e.id.LocalPort
+ if e.bindNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.bindNICID {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nicid = e.bindNICID
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ id := stack.TransportEndpointID{
+ LocalAddress: r.LocalAddress,
+ LocalPort: localPort,
+ RemoteAddress: r.RemoteAddress,
+ }
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.id = id
+ e.route = r.Clone()
+ e.regNICID = nicid
+
+ e.state = stateConnected
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.shutdownFlags |= flags
+
+ if e.state != stateConnected {
+ return tcpip.ErrNotConnected
+ }
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.rcvMu.Lock()
+ wasClosed := e.rcvClosed
+ e.rcvClosed = true
+ e.rcvMu.Unlock()
+
+ if !wasClosed {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen is not supported by UDP, it just fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by UDP, it just fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if id.LocalPort != 0 {
+ // The endpoint already has a local port, just attempt to
+ // register it.
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ return id, err
+ }
+
+ // We need to find a port for the endpoint.
+ _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ id.LocalPort = p
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ switch err {
+ case nil:
+ return true, nil
+ case tcpip.ErrPortInUse:
+ return false, nil
+ default:
+ return false, err
+ }
+ })
+
+ return id, err
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+
+ if len(addr.Addr) != 0 {
+ // A local address was specified, verify that it's valid.
+ if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.id = id
+ e.regNICID = addr.NIC
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional.
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ err := e.bindLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ e.bindNICID = addr.NIC
+ e.bindAddr = addr.Addr
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ }, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ // Only accept echo replies.
+ switch e.netProto {
+ case header.IPv4ProtocolNumber:
+ h := header.ICMPv4(vv.First())
+ if h.Type() != header.ICMPv4EchoReply {
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ case header.IPv6ProtocolNumber:
+ h := header.ICMPv6(vv.First())
+ if h.Type() != header.ICMPv6EchoReply {
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ }
+
+ e.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ pkt := &icmpPacket{
+ senderAddress: tcpip.FullAddress{
+ NIC: r.NICID(),
+ Addr: id.RemoteAddress,
+ },
+ }
+
+ pkt.data = vv.Clone(pkt.views[:])
+
+ e.rcvList.PushBack(pkt)
+ e.rcvBufSize += pkt.data.Size()
+
+ pkt.timestamp = e.stack.NowNanoseconds()
+
+ e.rcvMu.Unlock()
+
+ // Notify any waiters that there's data to be read now.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
new file mode 100644
index 000000000..332b3cd33
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves icmpPacket.data field.
+func (p *icmpPacket) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads icmpPacket.data field.
+func (p *icmpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after savercvBufSizeMax(), which would have
+ // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ e.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ e.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
+ if err != nil {
+ panic(*err)
+ }
+
+ e.id.LocalAddress = e.route.LocalAddress
+ } else if len(e.id.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
+ if err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go
new file mode 100755
index 000000000..1b35e5b4a
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go
@@ -0,0 +1,173 @@
+package icmp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type icmpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type icmpPacketList struct {
+ head *icmpPacket
+ tail *icmpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *icmpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *icmpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *icmpPacketList) Front() *icmpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *icmpPacketList) Back() *icmpPacket {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *icmpPacketList) PushFront(e *icmpPacket) {
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(l.head)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *icmpPacketList) PushBack(e *icmpPacket) {
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(nil)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *icmpPacketList) PushBackList(m *icmpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) {
+ a := icmpPacketElementMapper{}.linkerFor(b).Next()
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ icmpPacketElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ icmpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) {
+ b := icmpPacketElementMapper{}.linkerFor(a).Prev()
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ icmpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ icmpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *icmpPacketList) Remove(e *icmpPacket) {
+ prev := icmpPacketElementMapper{}.linkerFor(e).Prev()
+ next := icmpPacketElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ icmpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type icmpPacketEntry struct {
+ next *icmpPacket
+ prev *icmpPacket
+}
+
+// Next returns the entry that follows e in the list.
+func (e *icmpPacketEntry) Next() *icmpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *icmpPacketEntry) Prev() *icmpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *icmpPacketEntry) SetNext(elem *icmpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
new file mode 100755
index 000000000..b66857348
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
@@ -0,0 +1,98 @@
+// automatically generated by stateify.
+
+package icmp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *icmpPacket) beforeSave() {}
+func (x *icmpPacket) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("icmpPacketEntry", &x.icmpPacketEntry)
+ m.Save("senderAddress", &x.senderAddress)
+ m.Save("timestamp", &x.timestamp)
+}
+
+func (x *icmpPacket) afterLoad() {}
+func (x *icmpPacket) load(m state.Map) {
+ m.Load("icmpPacketEntry", &x.icmpPacketEntry)
+ m.Load("senderAddress", &x.senderAddress)
+ m.Load("timestamp", &x.timestamp)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("netProto", &x.netProto)
+ m.Save("transProto", &x.transProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvReady", &x.rcvReady)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("id", &x.id)
+ m.Save("state", &x.state)
+ m.Save("bindNICID", &x.bindNICID)
+ m.Save("bindAddr", &x.bindAddr)
+ m.Save("regNICID", &x.regNICID)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.Load("transProto", &x.transProto)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("rcvReady", &x.rcvReady)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("id", &x.id)
+ m.Load("state", &x.state)
+ m.Load("bindNICID", &x.bindNICID)
+ m.Load("bindAddr", &x.bindAddr)
+ m.Load("regNICID", &x.regNICID)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *icmpPacketList) beforeSave() {}
+func (x *icmpPacketList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *icmpPacketList) afterLoad() {}
+func (x *icmpPacketList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *icmpPacketEntry) beforeSave() {}
+func (x *icmpPacketEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *icmpPacketEntry) afterLoad() {}
+func (x *icmpPacketEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("icmp.icmpPacket", (*icmpPacket)(nil), state.Fns{Save: (*icmpPacket).save, Load: (*icmpPacket).load})
+ state.Register("icmp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("icmp.icmpPacketList", (*icmpPacketList)(nil), state.Fns{Save: (*icmpPacketList).save, Load: (*icmpPacketList).load})
+ state.Register("icmp.icmpPacketEntry", (*icmpPacketEntry)(nil), state.Fns{Save: (*icmpPacketEntry).save, Load: (*icmpPacketEntry).load})
+}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
new file mode 100644
index 000000000..954fde9d8
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
+// protocols for use in ping. To use it in the networking stack, this package
+// must be added to the project, and
+// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or
+// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when
+// calling stack.New(). Then endpoints can be created by passing
+// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
+// when calling Stack.NewEndpoint().
+package icmp
+
+import (
+ "encoding/binary"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolName4 is the string representation of the icmp protocol name.
+ ProtocolName4 = "icmp4"
+
+ // ProtocolNumber4 is the ICMP protocol number.
+ ProtocolNumber4 = header.ICMPv4ProtocolNumber
+
+ // ProtocolName6 is the string representation of the icmp protocol name.
+ ProtocolName6 = "icmp6"
+
+ // ProtocolNumber6 is the IPv6-ICMP protocol number.
+ ProtocolNumber6 = header.ICMPv6ProtocolNumber
+)
+
+// protocol implements stack.TransportProtocol.
+type protocol struct {
+ number tcpip.TransportProtocolNumber
+}
+
+// Number returns the ICMP protocol number.
+func (p *protocol) Number() tcpip.TransportProtocolNumber {
+ return p.number
+}
+
+func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
+ switch p.number {
+ case ProtocolNumber4:
+ return header.IPv4ProtocolNumber
+ case ProtocolNumber6:
+ return header.IPv6ProtocolNumber
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// NewEndpoint creates a new icmp endpoint. It implements
+// stack.TransportProtocol.NewEndpoint.
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return newEndpoint(stack, netProto, p.number, waiterQueue)
+}
+
+// NewRawEndpoint creates a new raw icmp endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid icmp packet size.
+func (p *protocol) MinimumPacketSize() int {
+ switch p.number {
+ case ProtocolNumber4:
+ return header.ICMPv4EchoMinimumSize
+ case ProtocolNumber6:
+ return header.ICMPv6EchoMinimumSize
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// ParsePorts returns the source and destination ports stored in the given icmp
+// packet.
+func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ switch p.number {
+ case ProtocolNumber4:
+ return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil
+ case ProtocolNumber6:
+ return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+ return true
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+ })
+
+ stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
+ })
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
new file mode 100644
index 000000000..1daf5823f
--- /dev/null
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -0,0 +1,521 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package raw provides the implementation of raw sockets (see raw(7)). Raw
+// sockets allow applications to:
+//
+// * manually write and inspect transport layer headers and payloads
+// * receive all traffic of a given transport protcol (e.g. ICMP or UDP)
+// * optionally write and inspect network layer and link layer headers for
+// packets
+//
+// Raw sockets don't have any notion of ports, and incoming packets are
+// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will
+// receive every UDP packet received by netstack. bind(2) and connect(2) can be
+// used to filter incoming packets by source and destination.
+package raw
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type packet struct {
+ packetEntry
+ // data holds the actual packet data, including any headers and
+ // payload.
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // views is pre-allocated space to back data. As long as the packet is
+ // made up of fewer than 8 buffer.Views, no extra allocation is
+ // necessary to store packet data.
+ views [8]buffer.View `state:"nosave"`
+ // timestampNS is the unix time at which the packet was received.
+ timestampNS int64
+ // senderAddr is the network address of the sender.
+ senderAddr tcpip.FullAddress
+}
+
+// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to
+// have goroutines make concurrent calls into the endpoint.
+//
+// Lock order:
+// endpoint.mu
+// endpoint.rcvMu
+//
+// +stateify savable
+type endpoint struct {
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ closed bool
+ connected bool
+ bound bool
+ // registeredNIC is the NIC to which th endpoint is explicitly
+ // registered. Is set when Connect or Bind are used to specify a NIC.
+ registeredNIC tcpip.NICID
+ // boundNIC and boundAddr are set on calls to Bind(). When callers
+ // attempt actions that would invalidate the binding data (e.g. sending
+ // data via a NIC other than boundNIC), the endpoint will return an
+ // error.
+ boundNIC tcpip.NICID
+ boundAddr tcpip.Address
+ // route is the route to a remote network endpoint. It is set via
+ // Connect(), and is valid only when conneted is true.
+ route stack.Route `state:"manual"`
+}
+
+// NewEndpoint returns a raw endpoint for the given protocols.
+// TODO(b/129292371): IP_HDRINCL, IPPROTO_RAW, and AF_PACKET.
+func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != header.IPv4ProtocolNumber {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ ep := &endpoint{
+ stack: stack,
+ netProto: netProto,
+ transProto: transProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ return nil, err
+ }
+
+ return ep, nil
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (ep *endpoint) Close() {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return
+ }
+
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+
+ // Clear the receive list.
+ ep.rcvClosed = true
+ ep.rcvBufSize = 0
+ for !ep.rcvList.Empty() {
+ ep.rcvList.Remove(ep.rcvList.Front())
+ }
+
+ if ep.connected {
+ ep.route.Release()
+ }
+
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ ep.rcvMu.Lock()
+
+ // If there's no data to read, return that read would block or that the
+ // endpoint is closed.
+ if ep.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if ep.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ ep.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ packet := ep.rcvList.Front()
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = packet.senderAddr
+ }
+
+ return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+}
+
+// Write implements tcpip.Endpoint.Write.
+func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ ep.mu.RLock()
+
+ if ep.closed {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Did the user caller provide a destination? If not, use the connected
+ // destination.
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if !ep.connected {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrDestinationRequired
+ }
+
+ if ep.route.IsResolutionRequired() {
+ savedRoute := &ep.route
+ // Promote lock to exclusive if using a shared route,
+ // given that it may need to change in finishWrite.
+ ep.mu.RUnlock()
+ ep.mu.Lock()
+
+ // Make sure that the route didn't change during the
+ // time we didn't hold the lock.
+ if !ep.connected || savedRoute != &ep.route {
+ ep.mu.Unlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ n, ch, err := ep.finishWrite(payload, savedRoute)
+ ep.mu.Unlock()
+ return n, ch, err
+ }
+
+ n, ch, err := ep.finishWrite(payload, &ep.route)
+ ep.mu.RUnlock()
+ return n, ch, err
+ }
+
+ // The caller provided a destination. Reject destination address if it
+ // goes through a different NIC than the endpoint was bound to.
+ nic := opts.To.NIC
+ if ep.bound && nic != 0 && nic != ep.boundNIC {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ // We don't support IPv6 yet, so this has to be an IPv4 address.
+ if len(opts.To.Addr) != header.IPv4AddressSize {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Find the route to the destination. If boundAddress is 0,
+ // FindRoute will choose an appropriate source address.
+ route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false)
+ if err != nil {
+ ep.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ n, ch, err := ep.finishWrite(payload, &route)
+ route.Release()
+ ep.mu.RUnlock()
+ return n, ch, err
+}
+
+// finishWrite writes the payload to a route. It resolves the route if
+// necessary. It's really just a helper to make defer unnecessary in Write.
+func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // We may need to resolve the route (match a link layer address to the
+ // network address). If that requires blocking (e.g. to use ARP),
+ // return a channel on which the caller can wait.
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ payloadBytes, err := payload.Get(payload.Size())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch ep.netProto {
+ case header.IPv4ProtocolNumber:
+ hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
+ if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil {
+ return 0, nil, err
+ }
+
+ default:
+ return 0, nil, tcpip.ErrUnknownProtocol
+ }
+
+ return uintptr(len(payloadBytes)), nil, nil
+}
+
+// Peek implements tcpip.Endpoint.Peek.
+func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// Connect implements tcpip.Endpoint.Connect.
+func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // We don't support IPv6 yet.
+ if len(addr.Addr) != header.IPv4AddressSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nic := addr.NIC
+ if ep.bound {
+ if ep.boundNIC == 0 {
+ // If we're bound, but not to a specific NIC, the NIC
+ // in addr will be used. Nothing to do here.
+ } else if addr.NIC == 0 {
+ // If we're bound to a specific NIC, but addr doesn't
+ // specify a NIC, use the bound NIC.
+ nic = ep.boundNIC
+ } else if addr.NIC != ep.boundNIC {
+ // We're bound and addr specifies a NIC. They must be
+ // the same.
+ return tcpip.ErrInvalidEndpointState
+ }
+ }
+
+ // Find a route to the destination.
+ route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ // Save the route and NIC we've connected via.
+ ep.route = route.Clone()
+ ep.registeredNIC = nic
+ ep.connected = true
+
+ return nil
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
+func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if !ep.connected {
+ return tcpip.ErrNotConnected
+ }
+ return nil
+}
+
+// Listen implements tcpip.Endpoint.Listen.
+func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept implements tcpip.Endpoint.Accept.
+func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ // Callers must provide an IPv4 address or no network address (for
+ // binding to a NIC, but not an address).
+ if len(addr.Addr) != 0 && len(addr.Addr) != 4 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // If a local address was specified, verify that it's valid.
+ if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ ep.registeredNIC = addr.NIC
+ ep.boundNIC = addr.NIC
+ ep.boundAddr = addr.Addr
+ ep.bound = true
+
+ return nil
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ // Even a connected socket doesn't return a remote address.
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine whether the endpoint is readable.
+ if (mask & waiter.EventIn) != 0 {
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() || ep.rcvClosed {
+ result |= waiter.EventIn
+ }
+ ep.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
+ ep.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
+ ep.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ ep.rcvMu.Lock()
+ if ep.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := ep.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ ep.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
+func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+ ep.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax {
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.rcvMu.Unlock()
+ return
+ }
+
+ if ep.bound {
+ // If bound to a NIC, only accept data for that NIC.
+ if ep.boundNIC != 0 && ep.boundNIC != route.NICID() {
+ ep.rcvMu.Unlock()
+ return
+ }
+ // If bound to an address, only accept data for that address.
+ if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress {
+ ep.rcvMu.Unlock()
+ return
+ }
+ }
+
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if ep.connected && ep.route.RemoteAddress != route.RemoteAddress {
+ ep.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := ep.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ packet := &packet{
+ senderAddr: tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ },
+ }
+
+ combinedVV := netHeader.ToVectorisedView()
+ combinedVV.Append(vv)
+ packet.data = combinedVV.Clone(packet.views[:])
+ packet.timestampNS = ep.stack.NowNanoseconds()
+
+ ep.rcvList.PushBack(packet)
+ ep.rcvBufSize += packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ // Notify waiters that there's data to be read.
+ if wasEmpty {
+ ep.waiterQueue.Notify(waiter.EventIn)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
new file mode 100644
index 000000000..e8907ebb1
--- /dev/null
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -0,0 +1,88 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raw
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves packet.data field.
+func (p *packet) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads packet.data field.
+func (p *packet) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after saveRcvBufSizeMax(), which would have
+ // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+ max := ep.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ ep.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ ep.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+ ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+ // StackFromEnv is a stack used specifically for save/restore.
+ ep.stack = stack.StackFromEnv
+
+ // If the endpoint is connected, re-connect via the save/restore stack.
+ if ep.connected {
+ var err *tcpip.Error
+ ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
+ if err != nil {
+ panic(*err)
+ }
+ }
+
+ // If the endpoint is bound, re-bind via the save/restore stack.
+ if ep.bound {
+ if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/packet_list.go b/pkg/tcpip/transport/raw/packet_list.go
new file mode 100755
index 000000000..2e9074934
--- /dev/null
+++ b/pkg/tcpip/transport/raw/packet_list.go
@@ -0,0 +1,173 @@
+package raw
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type packetElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (packetElementMapper) linkerFor(elem *packet) *packet { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type packetList struct {
+ head *packet
+ tail *packet
+}
+
+// Reset resets list l to the empty state.
+func (l *packetList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *packetList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *packetList) Front() *packet {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *packetList) Back() *packet {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *packetList) PushFront(e *packet) {
+ packetElementMapper{}.linkerFor(e).SetNext(l.head)
+ packetElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ packetElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *packetList) PushBack(e *packet) {
+ packetElementMapper{}.linkerFor(e).SetNext(nil)
+ packetElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *packetList) PushBackList(m *packetList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *packetList) InsertAfter(b, e *packet) {
+ a := packetElementMapper{}.linkerFor(b).Next()
+ packetElementMapper{}.linkerFor(e).SetNext(a)
+ packetElementMapper{}.linkerFor(e).SetPrev(b)
+ packetElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ packetElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *packetList) InsertBefore(a, e *packet) {
+ b := packetElementMapper{}.linkerFor(a).Prev()
+ packetElementMapper{}.linkerFor(e).SetNext(a)
+ packetElementMapper{}.linkerFor(e).SetPrev(b)
+ packetElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ packetElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *packetList) Remove(e *packet) {
+ prev := packetElementMapper{}.linkerFor(e).Prev()
+ next := packetElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ packetElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ packetElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type packetEntry struct {
+ next *packet
+ prev *packet
+}
+
+// Next returns the entry that follows e in the list.
+func (e *packetEntry) Next() *packet {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *packetEntry) Prev() *packet {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *packetEntry) SetNext(elem *packet) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *packetEntry) SetPrev(elem *packet) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go
new file mode 100755
index 000000000..3327811b4
--- /dev/null
+++ b/pkg/tcpip/transport/raw/raw_state_autogen.go
@@ -0,0 +1,96 @@
+// automatically generated by stateify.
+
+package raw
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *packet) beforeSave() {}
+func (x *packet) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("packetEntry", &x.packetEntry)
+ m.Save("timestampNS", &x.timestampNS)
+ m.Save("senderAddr", &x.senderAddr)
+}
+
+func (x *packet) afterLoad() {}
+func (x *packet) load(m state.Map) {
+ m.Load("packetEntry", &x.packetEntry)
+ m.Load("timestampNS", &x.timestampNS)
+ m.Load("senderAddr", &x.senderAddr)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("netProto", &x.netProto)
+ m.Save("transProto", &x.transProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("closed", &x.closed)
+ m.Save("connected", &x.connected)
+ m.Save("bound", &x.bound)
+ m.Save("registeredNIC", &x.registeredNIC)
+ m.Save("boundNIC", &x.boundNIC)
+ m.Save("boundAddr", &x.boundAddr)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.Load("transProto", &x.transProto)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("closed", &x.closed)
+ m.Load("connected", &x.connected)
+ m.Load("bound", &x.bound)
+ m.Load("registeredNIC", &x.registeredNIC)
+ m.Load("boundNIC", &x.boundNIC)
+ m.Load("boundAddr", &x.boundAddr)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *packetList) beforeSave() {}
+func (x *packetList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *packetList) afterLoad() {}
+func (x *packetList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *packetEntry) beforeSave() {}
+func (x *packetEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *packetEntry) afterLoad() {}
+func (x *packetEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("raw.packet", (*packet)(nil), state.Fns{Save: (*packet).save, Load: (*packet).load})
+ state.Register("raw.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("raw.packetList", (*packetList)(nil), state.Fns{Save: (*packetList).save, Load: (*packetList).load})
+ state.Register("raw.packetEntry", (*packetEntry)(nil), state.Fns{Save: (*packetEntry).save, Load: (*packetEntry).load})
+}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
new file mode 100644
index 000000000..d4b860975
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -0,0 +1,499 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "crypto/sha1"
+ "encoding/binary"
+ "hash"
+ "io"
+ "log"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // tsLen is the length, in bits, of the timestamp in the SYN cookie.
+ tsLen = 8
+
+ // tsMask is a mask for timestamp values (i.e., tsLen bits).
+ tsMask = (1 << tsLen) - 1
+
+ // tsOffset is the offset, in bits, of the timestamp in the SYN cookie.
+ tsOffset = 24
+
+ // hashMask is the mask for hash values (i.e., tsOffset bits).
+ hashMask = (1 << tsOffset) - 1
+
+ // maxTSDiff is the maximum allowed difference between a received cookie
+ // timestamp and the current timestamp. If the difference is greater
+ // than maxTSDiff, the cookie is expired.
+ maxTSDiff = 2
+)
+
+var (
+ // SynRcvdCountThreshold is the global maximum number of connections
+ // that are allowed to be in SYN-RCVD state before TCP starts using SYN
+ // cookies to accept connections.
+ //
+ // It is an exported variable only for testing, and should not otherwise
+ // be used by importers of this package.
+ SynRcvdCountThreshold uint64 = 1000
+
+ // mssTable is a slice containing the possible MSS values that we
+ // encode in the SYN cookie with two bits.
+ mssTable = []uint16{536, 1300, 1440, 1460}
+)
+
+func encodeMSS(mss uint16) uint32 {
+ for i := len(mssTable) - 1; i > 0; i-- {
+ if mss >= mssTable[i] {
+ return uint32(i)
+ }
+ }
+ return 0
+}
+
+// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
+// protected by a mutex so that we can increment only when it's guaranteed not
+// to go above a threshold.
+var synRcvdCount struct {
+ sync.Mutex
+ value uint64
+ pending sync.WaitGroup
+}
+
+// listenContext is used by a listening endpoint to store state used while
+// listening for connections. This struct is allocated by the listen goroutine
+// and must not be accessed or have its methods called concurrently as they
+// may mutate the stored objects.
+type listenContext struct {
+ stack *stack.Stack
+ rcvWnd seqnum.Size
+ nonce [2][sha1.BlockSize]byte
+ listenEP *endpoint
+
+ hasherMu sync.Mutex
+ hasher hash.Hash
+ v6only bool
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
+func timeStamp() uint32 {
+ return uint32(time.Now().Unix()>>6) & tsMask
+}
+
+// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD
+// state. It succeeds if the increment doesn't make the count go beyond the
+// threshold, and fails otherwise.
+func incSynRcvdCount() bool {
+ synRcvdCount.Lock()
+
+ if synRcvdCount.value >= SynRcvdCountThreshold {
+ synRcvdCount.Unlock()
+ return false
+ }
+
+ synRcvdCount.pending.Add(1)
+ synRcvdCount.value++
+
+ synRcvdCount.Unlock()
+ return true
+}
+
+// decSynRcvdCount atomically decrements the global number of endpoints in
+// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount
+// succeeded.
+func decSynRcvdCount() {
+ synRcvdCount.Lock()
+
+ synRcvdCount.value--
+ synRcvdCount.pending.Done()
+ synRcvdCount.Unlock()
+}
+
+// newListenContext creates a new listen context.
+func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+ l := &listenContext{
+ stack: stack,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6only: v6only,
+ netProto: netProto,
+ listenEP: listenEP,
+ }
+
+ rand.Read(l.nonce[0][:])
+ rand.Read(l.nonce[1][:])
+
+ return l
+}
+
+// cookieHash calculates the cookieHash for the given id, timestamp and nonce
+// index. The hash is used to create and validate cookies.
+func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {
+
+ // Initialize block with fixed-size data: local ports and v.
+ var payload [8]byte
+ binary.BigEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.BigEndian.PutUint16(payload[2:], id.RemotePort)
+ binary.BigEndian.PutUint32(payload[4:], ts)
+
+ // Feed everything to the hasher.
+ l.hasherMu.Lock()
+ l.hasher.Reset()
+ l.hasher.Write(payload[:])
+ l.hasher.Write(l.nonce[nonceIndex][:])
+ io.WriteString(l.hasher, string(id.LocalAddress))
+ io.WriteString(l.hasher, string(id.RemoteAddress))
+
+ // Finalize the calculation of the hash and return the first 4 bytes.
+ h := make([]byte, 0, sha1.Size)
+ h = l.hasher.Sum(h)
+ l.hasherMu.Unlock()
+
+ return binary.BigEndian.Uint32(h[:])
+}
+
+// createCookie creates a SYN cookie for the given id and incoming sequence
+// number.
+func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value {
+ ts := timeStamp()
+ v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
+ v += (l.cookieHash(id, ts, 1) + data) & hashMask
+ return seqnum.Value(v)
+}
+
+// isCookieValid checks if the supplied cookie is valid for the given id and
+// sequence number. If it is, it also returns the data originally encoded in the
+// cookie when createCookie was called.
+func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
+ ts := timeStamp()
+ v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
+ cookieTS := v >> tsOffset
+ if ((ts - cookieTS) & tsMask) > maxTSDiff {
+ return 0, false
+ }
+
+ return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
+}
+
+// createConnectingEndpoint creates a new endpoint in a connecting state, with
+// the connection parameters given by the arguments.
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+ // Create a new endpoint.
+ netProto := l.netProto
+ if netProto == 0 {
+ netProto = s.route.NetProto
+ }
+ n := newEndpoint(l.stack, netProto, nil)
+ n.v6only = l.v6only
+ n.id = s.id
+ n.boundNICID = s.route.NICID()
+ n.route = s.route.Clone()
+ n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
+ n.rcvBufSize = int(l.rcvWnd)
+
+ n.maybeEnableTimestamp(rcvdSynOpts)
+ n.maybeEnableSACKPermitted(rcvdSynOpts)
+
+ n.initGSO()
+
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ n.Close()
+ return nil, err
+ }
+
+ n.isRegistered = true
+ n.state = stateConnecting
+
+ // Create sender and receiver.
+ //
+ // The receiver at least temporarily has a zero receive window scale,
+ // but the caller may change it (before starting the protocol loop).
+ n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
+ n.rcv = newReceiver(n, irs, l.rcvWnd, 0)
+
+ return n, nil
+}
+
+// createEndpoint creates a new endpoint in connected state and then performs
+// the TCP 3-way handshake.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+ // Create new endpoint.
+ irs := s.sequenceNumber
+ cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
+ ep, err := l.createConnectingEndpoint(s, cookie, irs, opts)
+ if err != nil {
+ return nil, err
+ }
+
+ // Perform the 3-way handshake.
+ h := newHandshake(ep, l.rcvWnd)
+
+ h.resetToSynRcvd(cookie, irs, opts, l.listenEP)
+ if err := h.execute(); err != nil {
+ ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ ep.Close()
+ return nil, err
+ }
+
+ ep.state = stateConnected
+
+ // Update the receive window scaling. We can't do it before the
+ // handshake because it's possible that the peer doesn't support window
+ // scaling.
+ ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
+
+ return ep, nil
+}
+
+// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
+// endpoint has transitioned out of the listen state, the new endpoint is closed
+// instead.
+func (e *endpoint) deliverAccepted(n *endpoint) {
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == stateListen {
+ e.acceptedChan <- n
+ e.waiterQueue.Notify(waiter.EventIn)
+ } else {
+ n.Close()
+ }
+}
+
+// handleSynSegment is called in its own goroutine once the listening endpoint
+// receives a SYN segment. It is responsible for completing the handshake and
+// queueing the new endpoint for acceptance.
+//
+// A limited number of these goroutines are allowed before TCP starts using SYN
+// cookies to accept connections.
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
+ defer decSynRcvdCount()
+ defer e.decSynRcvdCount()
+ defer s.decRef()
+
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ e.deliverAccepted(n)
+}
+
+func (e *endpoint) incSynRcvdCount() bool {
+ e.mu.Lock()
+ log.Printf("l: %d, c: %d, e.synRcvdCount: %d", len(e.acceptedChan), cap(e.acceptedChan), e.synRcvdCount)
+ if l, c := len(e.acceptedChan), cap(e.acceptedChan); l == c && e.synRcvdCount >= c {
+ e.mu.Unlock()
+ return false
+ }
+ e.synRcvdCount++
+ e.mu.Unlock()
+ return true
+}
+
+func (e *endpoint) decSynRcvdCount() {
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+}
+
+// handleListenSegment is called when a listening endpoint receives a segment
+// and needs to handle it.
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+ switch s.flags {
+ case header.TCPFlagSyn:
+ opts := parseSynSegmentOptions(s)
+ if incSynRcvdCount() {
+ // Drop the SYN if the listen endpoint's accept queue is
+ // overflowing.
+ if e.incSynRcvdCount() {
+ log.Printf("processing syn packet")
+ s.incRef()
+ go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
+ return
+ }
+ log.Printf("dropping syn packet")
+ e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ } else {
+ // TODO(bhaskerh): Increment syncookie sent stat.
+ cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ // Send SYN with window scaling because we currently
+ // dont't encode this information in the cookie.
+ //
+ // Enable Timestamp option if the original syn did have
+ // the timestamp option specified.
+ synOpts := header.TCPSynOptions{
+ WS: -1,
+ TS: opts.TS,
+ TSVal: tcpTimeStamp(timeStampOffset()),
+ TSEcr: opts.TSVal,
+ }
+ sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+ e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ }
+
+ case header.TCPFlagAck:
+ if len(e.acceptedChan) == cap(e.acceptedChan) {
+ // Silently drop the ack as the application can't accept
+ // the connection at this point. The ack will be
+ // retransmitted by the sender anyway and we can
+ // complete the connection at the time of retransmit if
+ // the backlog has space.
+ e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+
+ // Validate the cookie.
+ data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
+ if !ok || int(data) >= len(mssTable) {
+ e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
+ // Create newly accepted endpoint and deliver it.
+ rcvdSynOptions := &header.TCPSynOptions{
+ MSS: mssTable[data],
+ // Disable Window scaling as original SYN is
+ // lost.
+ WS: -1,
+ }
+
+ // When syn cookies are in use we enable timestamp only
+ // if the ack specifies the timestamp option assuming
+ // that the other end did in fact negotiate the
+ // timestamp option in the original SYN.
+ if s.parsedOptions.TS {
+ rcvdSynOptions.TS = true
+ rcvdSynOptions.TSVal = s.parsedOptions.TSVal
+ rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
+ }
+
+ n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ // clear the tsOffset for the newly created
+ // endpoint as the Timestamp was already
+ // randomly offset when the original SYN-ACK was
+ // sent above.
+ n.tsOffset = 0
+
+ // Switch state to connected.
+ n.state = stateConnected
+
+ // Do the delivery in a separate goroutine so
+ // that we don't block the listen loop in case
+ // the application is slow to accept or stops
+ // accepting.
+ //
+ // NOTE: This won't result in an unbounded
+ // number of goroutines as we do check before
+ // entering here that there was at least some
+ // space available in the backlog.
+ go e.deliverAccepted(n)
+ }
+}
+
+// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
+// its own goroutine and is responsible for handling connection requests.
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+ defer func() {
+ // Mark endpoint as closed. This will prevent goroutines running
+ // handleSynSegment() from attempting to queue new connections
+ // to the endpoint.
+ e.mu.Lock()
+ e.state = stateClosed
+
+ // Do cleanup if needed.
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+ e.mu.Unlock()
+
+ // Notify waiters that the endpoint is shutdown.
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+ }()
+
+ e.mu.Lock()
+ v6only := e.v6only
+ e.mu.Unlock()
+
+ ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+
+ s := sleep.Sleeper{}
+ s.AddWaker(&e.notificationWaker, wakerForNotification)
+ s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
+ for {
+ switch index, _ := s.Fetch(true); index {
+ case wakerForNotification:
+ n := e.fetchNotifications()
+ if n&notifyClose != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ s := e.segmentQueue.dequeue()
+ e.handleListenSegment(ctx, s)
+ s.decRef()
+ }
+ synRcvdCount.pending.Wait()
+ close(e.drainDone)
+ <-e.undrain
+ }
+
+ case wakerForNewSegment:
+ // Process at most maxSegmentsPerWake segments.
+ mayRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ mayRequeue = false
+ break
+ }
+
+ e.handleListenSegment(ctx, s)
+ s.decRef()
+ }
+
+ // If the queue is not empty, make sure we'll wake up
+ // in the next iteration.
+ if mayRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
new file mode 100644
index 000000000..2aed6f286
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -0,0 +1,1066 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// maxSegmentsPerWake is the maximum number of segments to process in the main
+// protocol goroutine per wake-up. Yielding [after this number of segments are
+// processed] allows other events to be processed as well (e.g., timeouts,
+// resets, etc.).
+const maxSegmentsPerWake = 100
+
+type handshakeState int
+
+// The following are the possible states of the TCP connection during a 3-way
+// handshake. A depiction of the states and transitions can be found in RFC 793,
+// page 23.
+const (
+ handshakeSynSent handshakeState = iota
+ handshakeSynRcvd
+ handshakeCompleted
+)
+
+// The following are used to set up sleepers.
+const (
+ wakerForNotification = iota
+ wakerForNewSegment
+ wakerForResend
+ wakerForResolution
+)
+
+const (
+ // Maximum space available for options.
+ maxOptionSize = 40
+)
+
+// handshake holds the state used during a TCP 3-way handshake.
+type handshake struct {
+ ep *endpoint
+ listenEP *endpoint // only non nil when doing passive connects.
+ state handshakeState
+ active bool
+ flags uint8
+ ackNum seqnum.Value
+
+ // iss is the initial send sequence number, as defined in RFC 793.
+ iss seqnum.Value
+
+ // rcvWnd is the receive window, as defined in RFC 793.
+ rcvWnd seqnum.Size
+
+ // sndWnd is the send window, as defined in RFC 793.
+ sndWnd seqnum.Size
+
+ // mss is the maximum segment size received from the peer.
+ mss uint16
+
+ // sndWndScale is the send window scale, as defined in RFC 1323. A
+ // negative value means no scaling is supported by the peer.
+ sndWndScale int
+
+ // rcvWndScale is the receive window scale, as defined in RFC 1323.
+ rcvWndScale int
+}
+
+func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
+ h := handshake{
+ ep: ep,
+ active: true,
+ rcvWnd: rcvWnd,
+ rcvWndScale: FindWndScale(rcvWnd),
+ }
+ h.resetState()
+ return h
+}
+
+// FindWndScale determines the window scale to use for the given maximum window
+// size.
+func FindWndScale(wnd seqnum.Size) int {
+ if wnd < 0x10000 {
+ return 0
+ }
+
+ max := seqnum.Size(0xffff)
+ s := 0
+ for wnd > max && s < header.MaxWndScale {
+ s++
+ max <<= 1
+ }
+
+ return s
+}
+
+// resetState resets the state of the handshake object such that it becomes
+// ready for a new 3-way handshake.
+func (h *handshake) resetState() {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+
+ h.state = handshakeSynSent
+ h.flags = header.TCPFlagSyn
+ h.ackNum = 0
+ h.mss = 0
+ h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24)
+}
+
+// effectiveRcvWndScale returns the effective receive window scale to be used.
+// If the peer doesn't support window scaling, the effective rcv wnd scale is
+// zero; otherwise it's the value calculated based on the initial rcv wnd.
+func (h *handshake) effectiveRcvWndScale() uint8 {
+ if h.sndWndScale < 0 {
+ return 0
+ }
+ return uint8(h.rcvWndScale)
+}
+
+// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
+// state.
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, listenEP *endpoint) {
+ h.active = false
+ h.state = handshakeSynRcvd
+ h.flags = header.TCPFlagSyn | header.TCPFlagAck
+ h.iss = iss
+ h.ackNum = irs + 1
+ h.mss = opts.MSS
+ h.sndWndScale = opts.WS
+ h.listenEP = listenEP
+}
+
+// checkAck checks if the ACK number, if present, of a segment received during
+// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in
+// response.
+func (h *handshake) checkAck(s *segment) bool {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 {
+ // RFC 793, page 36, states that a reset must be generated when
+ // the connection is in any non-synchronized state and an
+ // incoming segment acknowledges something not yet sent. The
+ // connection remains in the same state.
+ ack := s.sequenceNumber.Add(s.logicalLen())
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, s.ackNumber, ack, 0)
+ return false
+ }
+
+ return true
+}
+
+// synSentState handles a segment received when the TCP 3-way handshake is in
+// the SYN-SENT state.
+func (h *handshake) synSentState(s *segment) *tcpip.Error {
+ // RFC 793, page 37, states that in the SYN-SENT state, a reset is
+ // acceptable if the ack field acknowledges the SYN.
+ if s.flagIsSet(header.TCPFlagRst) {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
+ return tcpip.ErrConnectionRefused
+ }
+ return nil
+ }
+
+ if !h.checkAck(s) {
+ return nil
+ }
+
+ // We are in the SYN-SENT state. We only care about segments that have
+ // the SYN flag.
+ if !s.flagIsSet(header.TCPFlagSyn) {
+ return nil
+ }
+
+ // Parse the SYN options.
+ rcvSynOpts := parseSynSegmentOptions(s)
+
+ // Remember if the Timestamp option was negotiated.
+ h.ep.maybeEnableTimestamp(&rcvSynOpts)
+
+ // Remember if the SACKPermitted option was negotiated.
+ h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+
+ // Remember the sequence we'll ack from now on.
+ h.ackNum = s.sequenceNumber + 1
+ h.flags |= header.TCPFlagAck
+ h.mss = rcvSynOpts.MSS
+ h.sndWndScale = rcvSynOpts.WS
+
+ // If this is a SYN ACK response, we only need to acknowledge the SYN
+ // and the handshake is completed.
+ if s.flagIsSet(header.TCPFlagAck) {
+ h.state = handshakeCompleted
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
+ return nil
+ }
+
+ // A SYN segment was received, but no ACK in it. We acknowledge the SYN
+ // but resend our own SYN and wait for it to be acknowledged in the
+ // SYN-RCVD state.
+ h.state = handshakeSynRcvd
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: rcvSynOpts.TS,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+
+ // We only send SACKPermitted if the other side indicated it
+ // permits SACK. This is not explicitly defined in the RFC but
+ // this is the behaviour implemented by Linux.
+ SACKPermitted: rcvSynOpts.SACKPermitted,
+ }
+ sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+ return nil
+}
+
+// synRcvdState handles a segment received when the TCP 3-way handshake is in
+// the SYN-RCVD state.
+func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
+ if s.flagIsSet(header.TCPFlagRst) {
+ // RFC 793, page 37, states that in the SYN-RCVD state, a reset
+ // is acceptable if the sequence number is in the window.
+ if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
+ return tcpip.ErrConnectionRefused
+ }
+ return nil
+ }
+
+ if !h.checkAck(s) {
+ return nil
+ }
+
+ if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
+ // We received two SYN segments with different sequence
+ // numbers, so we reset this and restart the whole
+ // process, except that we don't reset the timer.
+ ack := s.sequenceNumber.Add(s.logicalLen())
+ seq := seqnum.Value(0)
+ if s.flagIsSet(header.TCPFlagAck) {
+ seq = s.ackNumber
+ }
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0)
+
+ if !h.active {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ h.resetState()
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: h.ep.sendTSOk,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+ SACKPermitted: h.ep.sackPermitted,
+ }
+ sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ return nil
+ }
+
+ // We have previously received (and acknowledged) the peer's SYN. If the
+ // peer acknowledges our SYN, the handshake is completed.
+ if s.flagIsSet(header.TCPFlagAck) {
+ // listenContext is also used by a tcp.Forwarder and in that
+ // context we do not have a listening endpoint to check the
+ // backlog. So skip this check if listenEP is nil.
+ if h.listenEP != nil && len(h.listenEP.acceptedChan) == cap(h.listenEP.acceptedChan) {
+ // If there is no space in the accept queue to accept
+ // this endpoint then silently drop this ACK. The peer
+ // will anyway resend the ack and we can complete the
+ // connection the next time it's retransmitted.
+ h.ep.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+ // If the timestamp option is negotiated and the segment does
+ // not carry a timestamp option then the segment must be dropped
+ // as per https://tools.ietf.org/html/rfc7323#section-3.2.
+ if h.ep.sendTSOk && !s.parsedOptions.TS {
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
+ // Update timestamp if required. See RFC7323, section-4.3.
+ if h.ep.sendTSOk && s.parsedOptions.TS {
+ h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
+ }
+ h.state = handshakeCompleted
+ return nil
+ }
+
+ return nil
+}
+
+func (h *handshake) handleSegment(s *segment) *tcpip.Error {
+ h.sndWnd = s.window
+ if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 {
+ h.sndWnd <<= uint8(h.sndWndScale)
+ }
+
+ switch h.state {
+ case handshakeSynRcvd:
+ return h.synRcvdState(s)
+ case handshakeSynSent:
+ return h.synSentState(s)
+ }
+ return nil
+}
+
+// processSegments goes through the segment queue and processes up to
+// maxSegmentsPerWake (if they're available).
+func (h *handshake) processSegments() *tcpip.Error {
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := h.ep.segmentQueue.dequeue()
+ if s == nil {
+ return nil
+ }
+
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+
+ // We stop processing packets once the handshake is completed,
+ // otherwise we may process packets meant to be processed by
+ // the main protocol goroutine.
+ if h.state == handshakeCompleted {
+ break
+ }
+ }
+
+ // If the queue is not empty, make sure we'll wake up in the next
+ // iteration.
+ if !h.ep.segmentQueue.empty() {
+ h.ep.newSegmentWaker.Assert()
+ }
+
+ return nil
+}
+
+func (h *handshake) resolveRoute() *tcpip.Error {
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ resolutionWaker := &sleep.Waker{}
+ s.AddWaker(resolutionWaker, wakerForResolution)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ defer s.Done()
+
+ // Initial action is to resolve route.
+ index := wakerForResolution
+ for {
+ switch index {
+ case wakerForResolution:
+ if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+ // Either success (err == nil) or failure.
+ return err
+ }
+ // Resolution not completed. Keep trying...
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if n&notifyClose != 0 {
+ h.ep.route.RemoveWaker(resolutionWaker)
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
+ }
+
+ // Wait for notification.
+ index, _ = s.Fetch(true)
+ }
+}
+
+// execute executes the TCP 3-way handshake.
+func (h *handshake) execute() *tcpip.Error {
+ if h.ep.route.IsResolutionRequired() {
+ if err := h.resolveRoute(); err != nil {
+ return err
+ }
+ }
+
+ // Initialize the resend timer.
+ resendWaker := sleep.Waker{}
+ timeOut := time.Duration(time.Second)
+ rt := time.AfterFunc(timeOut, func() {
+ resendWaker.Assert()
+ })
+ defer rt.Stop()
+
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ s.AddWaker(&resendWaker, wakerForResend)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+ defer s.Done()
+
+ var sackEnabled SACKEnabled
+ if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
+ // If stack returned an error when checking for SACKEnabled
+ // status then just default to switching off SACK negotiation.
+ sackEnabled = false
+ }
+
+ // Send the initial SYN segment and loop until the handshake is
+ // completed.
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: true,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+ SACKPermitted: bool(sackEnabled),
+ }
+
+ // Execute is also called in a listen context so we want to make sure we
+ // only send the TS/SACK option when we received the TS/SACK in the
+ // initial SYN.
+ if h.state == handshakeSynRcvd {
+ synOpts.TS = h.ep.sendTSOk
+ synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
+ }
+ sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ for h.state != handshakeCompleted {
+ switch index, _ := s.Fetch(true); index {
+ case wakerForResend:
+ timeOut *= 2
+ if timeOut > 60*time.Second {
+ return tcpip.ErrTimeout
+ }
+ rt.Reset(timeOut)
+ sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if n&notifyClose != 0 {
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ for !h.ep.segmentQueue.empty() {
+ s := h.ep.segmentQueue.dequeue()
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+ if h.state == handshakeCompleted {
+ return nil
+ }
+ }
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
+
+ case wakerForNewSegment:
+ if err := h.processSegments(); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
+ synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
+ if synOpts.TS {
+ s.parsedOptions.TSVal = synOpts.TSVal
+ s.parsedOptions.TSEcr = synOpts.TSEcr
+ }
+ return synOpts
+}
+
+var optionPool = sync.Pool{
+ New: func() interface{} {
+ return make([]byte, maxOptionSize)
+ },
+}
+
+func getOptions() []byte {
+ return optionPool.Get().([]byte)
+}
+
+func putOptions(options []byte) {
+ // Reslice to full capacity.
+ optionPool.Put(options[0:cap(options)])
+}
+
+func makeSynOptions(opts header.TCPSynOptions) []byte {
+ // Emulate linux option order. This is as follows:
+ //
+ // if md5: NOP NOP MD5SIG 18 md5sig(16)
+ // if mss: MSS 4 mss(2)
+ // if ts and sack_advertise:
+ // SACK 2 TIMESTAMP 2 timestamp(8)
+ // elif ts: NOP NOP TIMESTAMP 10 timestamp(8)
+ // elif sack: NOP NOP SACK 2
+ // if wscale: NOP WINDOW 3 ws(1)
+ // if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8))
+ // [for each block] start_seq(4) end_seq(4)
+ // if fastopen_cookie:
+ // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2)
+ // else: FASTOPEN (2 + len(cookie))
+ // cookie(variable) [padding to four bytes]
+ //
+ options := getOptions()
+
+ // Always encode the mss.
+ offset := header.EncodeMSSOption(uint32(opts.MSS), options)
+
+ // Special ordering is required here. If both TS and SACK are enabled,
+ // then the SACK option precedes TS, with no padding. If they are
+ // enabled individually, then we see padding before the option.
+ if opts.TS && opts.SACKPermitted {
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.TS {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.SACKPermitted {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ }
+
+ // Initialize the WS option.
+ if opts.WS >= 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeWSOption(opts.WS, options[offset:])
+ }
+
+ // Padding to the end; note that this never apply unless we add a
+ // fastopen option, we always expect the offset to remain the same.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
+
+func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
+ // The MSS in opts is automatically calculated as this function is
+ // called from many places and we don't want every call point being
+ // embedded with the MSS calculation.
+ if opts.MSS == 0 {
+ opts.MSS = uint16(r.MTU() - header.TCPMinimumSize)
+ }
+
+ options := makeSynOptions(opts)
+ err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil)
+ putOptions(options)
+ return err
+}
+
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ // Allocate a buffer for the TCP header.
+ hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
+
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ // Initialize the header.
+ tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
+ tcp.Encode(&header.TCPFields{
+ SrcPort: id.LocalPort,
+ DstPort: id.RemotePort,
+ SeqNum: uint32(seq),
+ AckNum: uint32(ack),
+ DataOffset: uint8(header.TCPMinimumSize + optLen),
+ Flags: flags,
+ WindowSize: uint16(rcvWnd),
+ })
+ copy(tcp[header.TCPMinimumSize:], opts)
+
+ length := uint16(hdr.UsedLength() + data.Size())
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ // Only calculate the checksum if offloading isn't supported.
+ if gso != nil && gso.NeedsCsum {
+ // This is called CHECKSUM_PARTIAL in the Linux kernel. We
+ // calculate a checksum of the pseudo-header and save it in the
+ // TCP header, then the kernel calculate a checksum of the
+ // header and data and get the right sum of the TCP packet.
+ tcp.SetChecksum(xsum)
+ } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ xsum = header.ChecksumVV(data, xsum)
+ tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
+ }
+
+ r.Stats().TCP.SegmentsSent.Increment()
+ if (flags & header.TCPFlagRst) != 0 {
+ r.Stats().TCP.ResetsSent.Increment()
+ }
+
+ return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl)
+}
+
+// makeOptions makes an options slice.
+func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
+ options := getOptions()
+ offset := 0
+
+ // N.B. the ordering here matches the ordering used by Linux internally
+ // and described in the raw makeOptions function. We don't include
+ // unnecessary cases here (post connection.)
+ if e.sendTSOk {
+ // Embed the timestamp if timestamp has been enabled.
+ //
+ // We only use the lower 32 bits of the unix time in
+ // milliseconds. This is similar to what Linux does where it
+ // uses the lower 32 bits of the jiffies value in the tsVal
+ // field of the timestamp option.
+ //
+ // Further, RFC7323 section-5.4 recommends millisecond
+ // resolution as the lowest recommended resolution for the
+ // timestamp clock.
+ //
+ // Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:])
+ }
+ if e.sackPermitted && len(sackBlocks) > 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
+ }
+
+ // We expect the above to produce an aligned offset.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
+
+// sendRaw sends a TCP segment to the endpoint's peer.
+func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+ var sackBlocks []header.SACKBlock
+ if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
+ }
+ options := e.makeOptions(sackBlocks)
+ err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso)
+ putOptions(options)
+ return err
+}
+
+func (e *endpoint) handleWrite() *tcpip.Error {
+ // Move packets from send queue to send list. The queue is accessible
+ // from other goroutines and protected by the send mutex, while the send
+ // list is only accessible from the handler goroutine, so it needs no
+ // mutexes.
+ e.sndBufMu.Lock()
+
+ first := e.sndQueue.Front()
+ if first != nil {
+ e.snd.writeList.PushBackList(&e.sndQueue)
+ e.snd.sndNxtList.UpdateForward(e.sndBufInQueue)
+ e.sndBufInQueue = 0
+ }
+
+ e.sndBufMu.Unlock()
+
+ // Initialize the next segment to write if it's currently nil.
+ if e.snd.writeNext == nil {
+ e.snd.writeNext = first
+ }
+
+ // Push out any new packets.
+ e.snd.sendData()
+
+ return nil
+}
+
+func (e *endpoint) handleClose() *tcpip.Error {
+ // Drain the send queue.
+ e.handleWrite()
+
+ // Mark send side as closed.
+ e.snd.closed = true
+
+ return nil
+}
+
+// resetConnectionLocked sends a RST segment and puts the endpoint in an error
+// state with the given error code. This method must only be called from the
+// protocol goroutine.
+func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+
+ e.state = stateError
+ e.hardError = err
+}
+
+// completeWorkerLocked is called by the worker goroutine when it's about to
+// exit. It marks the worker as completed and performs cleanup work if requested
+// by Close().
+func (e *endpoint) completeWorkerLocked() {
+ e.workerRunning = false
+ if e.workerCleanup {
+ e.cleanupLocked()
+ }
+}
+
+// handleSegments pulls segments from the queue and processes them. It returns
+// no error if the protocol loop should continue, an error otherwise.
+func (e *endpoint) handleSegments() *tcpip.Error {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+
+ // Invoke the tcp probe if installed.
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+
+ if s.flagIsSet(header.TCPFlagRst) {
+ if e.rcv.acceptable(s.sequenceNumber, 0) {
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+ s.decRef()
+ return tcpip.ErrConnectionReset
+ }
+ } else if s.flagIsSet(header.TCPFlagAck) {
+ // Patch the window size in the segment according to the
+ // send window scale.
+ s.window <<= e.snd.sndWndScale
+
+ // RFC 793, page 41 states that "once in the ESTABLISHED
+ // state all segments must carry current acknowledgment
+ // information."
+ e.rcv.handleRcvdSegment(s)
+ e.snd.handleRcvdSegment(s)
+ }
+ s.decRef()
+ }
+
+ // If the queue is not empty, make sure we'll wake up in the next
+ // iteration.
+ if checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+
+ // Send an ACK for all processed packets if needed.
+ if e.rcv.rcvNxt != e.snd.maxSentAck {
+ e.snd.sendAck()
+ }
+
+ e.resetKeepaliveTimer(true)
+
+ return nil
+}
+
+// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
+// keepalive packets periodically when the connection is idle. If we don't hear
+// from the other side after a number of tries, we terminate the connection.
+func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+ e.keepalive.Lock()
+ if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
+ e.keepalive.Unlock()
+ return nil
+ }
+
+ if e.keepalive.unacked >= e.keepalive.count {
+ e.keepalive.Unlock()
+ return tcpip.ErrConnectionReset
+ }
+
+ // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
+ // seg.seq = snd.nxt-1.
+ e.keepalive.unacked++
+ e.keepalive.Unlock()
+ e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1)
+ e.resetKeepaliveTimer(false)
+ return nil
+}
+
+// resetKeepaliveTimer restarts or stops the keepalive timer, depending on
+// whether it is enabled for this endpoint.
+func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
+ e.keepalive.Lock()
+ defer e.keepalive.Unlock()
+ if receivedData {
+ e.keepalive.unacked = 0
+ }
+ // Start the keepalive timer IFF it's enabled and there is no pending
+ // data to send.
+ if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ e.keepalive.timer.disable()
+ return
+ }
+ if e.keepalive.unacked > 0 {
+ e.keepalive.timer.enable(e.keepalive.interval)
+ } else {
+ e.keepalive.timer.enable(e.keepalive.idle)
+ }
+}
+
+// disableKeepaliveTimer stops the keepalive timer.
+func (e *endpoint) disableKeepaliveTimer() {
+ e.keepalive.Lock()
+ e.keepalive.timer.disable()
+ e.keepalive.Unlock()
+}
+
+// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
+// goroutine and is responsible for sending segments and handling received
+// segments.
+func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
+ var closeTimer *time.Timer
+ var closeWaker sleep.Waker
+
+ epilogue := func() {
+ // e.mu is expected to be hold upon entering this section.
+
+ if e.snd != nil {
+ e.snd.resendTimer.cleanup()
+ }
+
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+
+ e.mu.Unlock()
+
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+
+ if handshake {
+ // This is an active connection, so we must initiate the 3-way
+ // handshake, and then inform potential waiters about its
+ // completion.
+ h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable()))
+ if err := h.execute(); err != nil {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ e.mu.Lock()
+ e.state = stateError
+ e.hardError = err
+ // Lock released below.
+ epilogue()
+
+ return err
+ }
+
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ e.rcvListMu.Unlock()
+ }
+
+ e.keepalive.timer.init(&e.keepalive.waker)
+ defer e.keepalive.timer.cleanup()
+
+ // Tell waiters that the endpoint is connected and writable.
+ e.mu.Lock()
+ e.state = stateConnected
+ drained := e.drainDone != nil
+ e.mu.Unlock()
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
+
+ e.waiterQueue.Notify(waiter.EventOut)
+
+ // Set up the functions that will be called when the main protocol loop
+ // wakes up.
+ funcs := []struct {
+ w *sleep.Waker
+ f func() *tcpip.Error
+ }{
+ {
+ w: &e.sndWaker,
+ f: e.handleWrite,
+ },
+ {
+ w: &e.sndCloseWaker,
+ f: e.handleClose,
+ },
+ {
+ w: &e.newSegmentWaker,
+ f: e.handleSegments,
+ },
+ {
+ w: &closeWaker,
+ f: func() *tcpip.Error {
+ return tcpip.ErrConnectionAborted
+ },
+ },
+ {
+ w: &e.snd.resendWaker,
+ f: func() *tcpip.Error {
+ if !e.snd.retransmitTimerExpired() {
+ return tcpip.ErrTimeout
+ }
+ return nil
+ },
+ },
+ {
+ w: &e.keepalive.waker,
+ f: e.keepaliveTimerExpired,
+ },
+ {
+ w: &e.notificationWaker,
+ f: func() *tcpip.Error {
+ n := e.fetchNotifications()
+ if n&notifyNonZeroReceiveWindow != 0 {
+ e.rcv.nonZeroWindow()
+ }
+
+ if n&notifyReceiveWindowChanged != 0 {
+ e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
+ }
+
+ if n&notifyMTUChanged != 0 {
+ e.sndBufMu.Lock()
+ count := e.packetTooBigCount
+ e.packetTooBigCount = 0
+ mtu := e.sndMTU
+ e.sndBufMu.Unlock()
+
+ e.snd.updateMaxPayloadSize(mtu, count)
+ }
+
+ if n&notifyReset != 0 {
+ e.mu.Lock()
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.mu.Unlock()
+ }
+ if n&notifyClose != 0 && closeTimer == nil {
+ // Reset the connection 3 seconds after
+ // the endpoint has been closed.
+ //
+ // The timer could fire in background
+ // when the endpoint is drained. That's
+ // OK as the loop here will not honor
+ // the firing until the undrain arrives.
+ closeTimer = time.AfterFunc(3*time.Second, func() {
+ closeWaker.Assert()
+ })
+ }
+
+ if n&notifyKeepaliveChanged != 0 {
+ // The timer could fire in background
+ // when the endpoint is drained. That's
+ // OK. See above.
+ e.resetKeepaliveTimer(true)
+ }
+
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ if err := e.handleSegments(); err != nil {
+ return err
+ }
+ }
+ if e.state != stateError {
+ close(e.drainDone)
+ <-e.undrain
+ }
+ }
+
+ return nil
+ },
+ },
+ }
+
+ // Initialize the sleeper based on the wakers in funcs.
+ s := sleep.Sleeper{}
+ for i := range funcs {
+ s.AddWaker(funcs[i].w, i)
+ }
+
+ // The following assertions and notifications are needed for restored
+ // endpoints. Fresh newly created endpoints have empty states and should
+ // not invoke any.
+ e.segmentQueue.mu.Lock()
+ if !e.segmentQueue.list.Empty() {
+ e.newSegmentWaker.Assert()
+ }
+ e.segmentQueue.mu.Unlock()
+
+ e.rcvListMu.Lock()
+ if !e.rcvList.Empty() {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ e.rcvListMu.Unlock()
+
+ e.mu.RLock()
+ if e.workerCleanup {
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+ e.mu.RUnlock()
+
+ // Main loop. Handle segments until both send and receive ends of the
+ // connection have completed.
+ for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
+ e.workMu.Unlock()
+ v, _ := s.Fetch(true)
+ e.workMu.Lock()
+ if err := funcs[v].f(); err != nil {
+ e.mu.Lock()
+ e.resetConnectionLocked(err)
+ // Lock released below.
+ epilogue()
+
+ return nil
+ }
+ }
+
+ // Mark endpoint as closed.
+ e.mu.Lock()
+ if e.state != stateError {
+ e.state = stateClosed
+ }
+ // Lock released below.
+ epilogue()
+
+ return nil
+}
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
new file mode 100644
index 000000000..e618cd2b9
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -0,0 +1,233 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "math"
+ "time"
+)
+
+// cubicState stores the variables related to TCP CUBIC congestion
+// control algorithm state.
+//
+// See: https://tools.ietf.org/html/rfc8312.
+type cubicState struct {
+ // wLastMax is the previous wMax value.
+ wLastMax float64
+
+ // wMax is the value of the congestion window at the
+ // time of last congestion event.
+ wMax float64
+
+ // t denotes the time when the current congestion avoidance
+ // was entered.
+ t time.Time
+
+ // numCongestionEvents tracks the number of congestion events since last
+ // RTO.
+ numCongestionEvents int
+
+ // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as
+ // per RFC.
+ c float64
+
+ // k is the time period that the above function takes to increase the
+ // current window size to W_max if there are no further congestion
+ // events and is calculated using the following equation:
+ //
+ // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
+ k float64
+
+ // beta is the CUBIC multiplication decrease factor. that is, when a
+ // congestion event is detected, CUBIC reduces its cwnd to
+ // W_cubic(0)=W_max*beta_cubic.
+ beta float64
+
+ // wC is window computed by CUBIC at time t. It's calculated using the
+ // formula:
+ //
+ // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
+ wC float64
+
+ // wEst is the window computed by CUBIC at time t+RTT i.e
+ // W_cubic(t+RTT).
+ wEst float64
+
+ s *sender
+}
+
+// newCubicCC returns a partially initialized cubic state with the constants
+// beta and c set and t set to current time.
+func newCubicCC(s *sender) *cubicState {
+ return &cubicState{
+ t: time.Now(),
+ beta: 0.7,
+ c: 0.4,
+ s: s,
+ }
+}
+
+// enterCongestionAvoidance is used to initialize cubic in cases where we exit
+// SlowStart without a real congestion event taking place. This can happen when
+// a connection goes back to slow start due to a retransmit and we exceed the
+// previously lowered ssThresh without experiencing packet loss.
+//
+// Refer: https://tools.ietf.org/html/rfc8312#section-4.8
+func (c *cubicState) enterCongestionAvoidance() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.7 &
+ // https://tools.ietf.org/html/rfc8312#section-4.8
+ if c.numCongestionEvents == 0 {
+ c.k = 0
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+ }
+}
+
+// updateSlowStart will update the congestion window as per the slow-start
+// algorithm used by NewReno. If after adjusting the congestion window we cross
+// the ssThresh then it will return the number of packets that must be consumed
+// in congestion avoidance mode.
+func (c *cubicState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := c.s.sndCwnd + packetsAcked
+ enterCA := false
+ if newcwnd >= c.s.sndSsthresh {
+ newcwnd = c.s.sndSsthresh
+ c.s.sndCAAckCount = 0
+ enterCA = true
+ }
+
+ packetsAcked -= newcwnd - c.s.sndCwnd
+ c.s.sndCwnd = newcwnd
+ if enterCA {
+ c.enterCongestionAvoidance()
+ }
+ return packetsAcked
+}
+
+// Update updates cubic's internal state variables. It must be called on every
+// ACK received.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) Update(packetsAcked int) {
+ if c.s.sndCwnd < c.s.sndSsthresh {
+ packetsAcked = c.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ } else {
+ c.s.rtt.Lock()
+ srtt := c.s.rtt.srtt
+ c.s.rtt.Unlock()
+ c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt)
+ }
+}
+
+// cubicCwnd computes the CUBIC congestion window after t seconds from last
+// congestion event.
+func (c *cubicState) cubicCwnd(t float64) float64 {
+ return c.c*math.Pow(t, 3.0) + c.wMax
+}
+
+// getCwnd returns the current congestion window as computed by CUBIC.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
+ elapsed := time.Since(c.t).Seconds()
+
+ // Compute the window as per Cubic after 'elapsed' time
+ // since last congestion event.
+ c.wC = c.cubicCwnd(elapsed - c.k)
+
+ // Compute the TCP friendly estimate of the congestion window.
+ c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds())
+
+ // Make sure in the TCP friendly region CUBIC performs at least
+ // as well as Reno.
+ if c.wC < c.wEst && float64(sndCwnd) < c.wEst {
+ // TCP Friendly region of cubic.
+ return int(c.wEst)
+ }
+
+ // In Concave/Convex region of CUBIC, calculate what CUBIC window
+ // will be after 1 RTT and use that to grow congestion window
+ // for every ack.
+ tEst := (time.Since(c.t) + srtt).Seconds()
+ wtRtt := c.cubicCwnd(tEst - c.k)
+ // As per 4.3 for each received ACK cwnd must be incremented
+ // by (w_cubic(t+RTT) - cwnd/cwnd.
+ cwnd := float64(sndCwnd)
+ for i := 0; i < packetsAcked; i++ {
+ // Concave/Convex regions of cubic have the same formulas.
+ // See: https://tools.ietf.org/html/rfc8312#section-4.3
+ cwnd += (wtRtt - cwnd) / cwnd
+ }
+ return int(cwnd)
+}
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (c *cubicState) HandleNDupAcks() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.5
+ c.numCongestionEvents++
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+ c.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionContrl.HandleRTOExpired.
+func (c *cubicState) HandleRTOExpired() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.6
+ c.t = time.Now()
+ c.numCongestionEvents = 0
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+
+ // We lost a packet, so reduce ssthresh.
+ c.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ c.s.sndCwnd = 1
+}
+
+// fastConvergence implements the logic for Fast Convergence algorithm as
+// described in https://tools.ietf.org/html/rfc8312#section-4.6.
+func (c *cubicState) fastConvergence() {
+ if c.wMax < c.wLastMax {
+ c.wLastMax = c.wMax
+ c.wMax = c.wMax * (1.0 + c.beta) / 2.0
+ } else {
+ c.wLastMax = c.wMax
+ }
+ // Recompute k as wMax may have changed.
+ c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c)
+}
+
+// PostRecovery implemements congestionControl.PostRecovery.
+func (c *cubicState) PostRecovery() {
+ c.t = time.Now()
+}
+
+// reduceSlowStartThreshold returns new SsThresh as described in
+// https://tools.ietf.org/html/rfc8312#section-4.7.
+func (c *cubicState) reduceSlowStartThreshold() {
+ c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0))
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
new file mode 100644
index 000000000..fd697402e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -0,0 +1,1741 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "math"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tmutex"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateListen
+ stateConnecting
+ stateConnected
+ stateClosed
+ stateError
+)
+
+// Reasons for notifying the protocol goroutine.
+const (
+ notifyNonZeroReceiveWindow = 1 << iota
+ notifyReceiveWindowChanged
+ notifyClose
+ notifyMTUChanged
+ notifyDrain
+ notifyReset
+ notifyKeepaliveChanged
+)
+
+// SACKInfo holds TCP SACK related information for a given endpoint.
+//
+// +stateify savable
+type SACKInfo struct {
+ // Blocks is the maximum number of SACK blocks we track
+ // per endpoint.
+ Blocks [MaxSACKBlocks]header.SACKBlock
+
+ // NumBlocks is the number of valid SACK blocks stored in the
+ // blocks array above.
+ NumBlocks int
+}
+
+// endpoint represents a TCP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized. The protocol implementation, however, runs in a single
+// goroutine.
+//
+// +stateify savable
+type endpoint struct {
+ // workMu is used to arbitrate which goroutine may perform protocol
+ // work. Only the main protocol goroutine is expected to call Lock() on
+ // it, but other goroutines (e.g., send) may call TryLock() to eagerly
+ // perform work without having to wait for the main one to wake up.
+ workMu tmutex.Mutex `state:"nosave"`
+
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue `state:"wait"`
+
+ // lastError represents the last error that the endpoint reported;
+ // access to it is protected by the following mutex.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
+
+ // The following fields are used to manage the receive queue. The
+ // protocol goroutine adds ready-for-delivery segments to rcvList,
+ // which are returned by Read() calls to users.
+ //
+ // Once the peer has closed its send side, rcvClosed is set to true
+ // to indicate to users that no more data is coming.
+ //
+ // rcvListMu can be taken after the endpoint mu below.
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ rcvBufSize int
+ rcvBufUsed int
+
+ // The following fields are protected by the mutex.
+ mu sync.RWMutex `state:"nosave"`
+ id stack.TransportEndpointID
+ state endpointState `state:".(endpointState)"`
+ isPortReserved bool `state:"manual"`
+ isRegistered bool
+ boundNICID tcpip.NICID `state:"manual"`
+ route stack.Route `state:"manual"`
+ v6only bool
+ isConnectNotified bool
+ // TCP should never broadcast but Linux nevertheless supports enabling/
+ // disabling SO_BROADCAST, albeit as a NOOP.
+ broadcast bool
+
+ // effectiveNetProtos contains the network protocols actually in use. In
+ // most cases it will only contain "netProto", but in cases like IPv6
+ // endpoints with v6only set to false, this could include multiple
+ // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+ // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+ // address).
+ effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
+
+ // hardError is meaningful only when state is stateError, it stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. hardError is protected by mu.
+ hardError *tcpip.Error `state:".(string)"`
+
+ // workerRunning specifies if a worker goroutine is running.
+ workerRunning bool
+
+ // workerCleanup specifies if the worker goroutine must perform cleanup
+ // before exitting. This can only be set to true when workerRunning is
+ // also true, and they're both protected by the mutex.
+ workerCleanup bool
+
+ // sendTSOk is used to indicate when the TS Option has been negotiated.
+ // When sendTSOk is true every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1
+ sendTSOk bool
+
+ // recentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field is
+ // updated if required when a new segment is received by this endpoint.
+ recentTS uint32
+
+ // tsOffset is a randomized offset added to the value of the
+ // TSVal field in the timestamp option.
+ tsOffset uint32
+
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
+ // sackPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ sackPermitted bool
+
+ // sack holds TCP SACK related information for this endpoint.
+ sack SACKInfo
+
+ // reusePort is set to true if SO_REUSEPORT is enabled.
+ reusePort bool
+
+ // delay enables Nagle's algorithm.
+ //
+ // delay is a boolean (0 is false) and must be accessed atomically.
+ delay uint32
+
+ // cork holds back segments until full.
+ //
+ // cork is a boolean (0 is false) and must be accessed atomically.
+ cork uint32
+
+ // scoreboard holds TCP SACK Scoreboard information for this endpoint.
+ scoreboard *SACKScoreboard
+
+ // The options below aren't implemented, but we remember the user
+ // settings because applications expect to be able to set/query these
+ // options.
+ reuseAddr bool
+
+ // slowAck holds the negated state of quick ack. It is stubbed out and
+ // does nothing.
+ //
+ // slowAck is a boolean (0 is false) and must be accessed atomically.
+ slowAck uint32
+
+ // segmentQueue is used to hand received segments to the protocol
+ // goroutine. Segments are queued as long as the queue is not full,
+ // and dropped when it is.
+ segmentQueue segmentQueue `state:"wait"`
+
+ // synRcvdCount is the number of connections for this endpoint that are
+ // in SYN-RCVD state.
+ synRcvdCount int
+
+ // The following fields are used to manage the send buffer. When
+ // segments are ready to be sent, they are added to sndQueue and the
+ // protocol goroutine is signaled via sndWaker.
+ //
+ // When the send side is closed, the protocol goroutine is notified via
+ // sndCloseWaker, and sndClosed is set to true.
+ sndBufMu sync.Mutex `state:"nosave"`
+ sndBufSize int
+ sndBufUsed int
+ sndClosed bool
+ sndBufInQueue seqnum.Size
+ sndQueue segmentList `state:"wait"`
+ sndWaker sleep.Waker `state:"manual"`
+ sndCloseWaker sleep.Waker `state:"manual"`
+
+ // cc stores the name of the Congestion Control algorithm to use for
+ // this endpoint.
+ cc CongestionControlOption
+
+ // The following are used when a "packet too big" control packet is
+ // received. They are protected by sndBufMu. They are used to
+ // communicate to the main protocol goroutine how many such control
+ // messages have been received since the last notification was processed
+ // and what was the smallest MTU seen.
+ packetTooBigCount int
+ sndMTU int
+
+ // newSegmentWaker is used to indicate to the protocol goroutine that
+ // it needs to wake up and handle new segments queued to it.
+ newSegmentWaker sleep.Waker `state:"manual"`
+
+ // notificationWaker is used to indicate to the protocol goroutine that
+ // it needs to wake up and check for notifications.
+ notificationWaker sleep.Waker `state:"manual"`
+
+ // notifyFlags is a bitmask of flags used to indicate to the protocol
+ // goroutine what it was notified; this is only accessed atomically.
+ notifyFlags uint32 `state:"nosave"`
+
+ // keepalive manages TCP keepalive state. When the connection is idle
+ // (no data sent or received) for keepaliveIdle, we start sending
+ // keepalives every keepalive.interval. If we send keepalive.count
+ // without hearing a response, the connection is closed.
+ keepalive keepalive
+
+ // acceptedChan is used by a listening endpoint protocol goroutine to
+ // send newly accepted connections to the endpoint so that they can be
+ // read by Accept() calls.
+ acceptedChan chan *endpoint `state:".([]*endpoint)"`
+
+ // The following are only used from the protocol goroutine, and
+ // therefore don't need locks to protect them.
+ rcv *receiver `state:"wait"`
+ snd *sender `state:"wait"`
+
+ // The goroutine drain completion notification channel.
+ drainDone chan struct{} `state:"nosave"`
+
+ // The goroutine undrain notification channel.
+ undrain chan struct{} `state:"nosave"`
+
+ // probe if not nil is invoked on every received segment. It is passed
+ // a copy of the current state of the endpoint.
+ probe stack.TCPProbeFunc `state:"nosave"`
+
+ // The following are only used to assist the restore run to re-connect.
+ bindAddress tcpip.Address
+ connectingAddress tcpip.Address
+
+ gso *stack.GSO
+}
+
+// StopWork halts packet processing. Only to be used in tests.
+func (e *endpoint) StopWork() {
+ e.workMu.Lock()
+}
+
+// ResumeWork resumes packet processing. Only to be used in tests.
+func (e *endpoint) ResumeWork() {
+ e.workMu.Unlock()
+}
+
+// keepalive is a synchronization wrapper used to appease stateify. See the
+// comment in endpoint, where it is used.
+//
+// +stateify savable
+type keepalive struct {
+ sync.Mutex `state:"nosave"`
+ enabled bool
+ idle time.Duration
+ interval time.Duration
+ count int
+ unacked int
+ timer timer `state:"nosave"`
+ waker sleep.Waker `state:"nosave"`
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+ e := &endpoint{
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ rcvBufSize: DefaultBufferSize,
+ sndBufSize: DefaultBufferSize,
+ sndMTU: int(math.MaxInt32),
+ reuseAddr: true,
+ keepalive: keepalive{
+ // Linux defaults.
+ idle: 2 * time.Hour,
+ interval: 75 * time.Second,
+ count: 9,
+ },
+ }
+
+ var ss SendBufferSizeOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ e.sndBufSize = ss.Default
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ e.rcvBufSize = rs.Default
+ }
+
+ var cs CongestionControlOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
+ e.cc = cs
+ }
+
+ if p := stack.GetTCPProbe(); p != nil {
+ e.probe = p
+ }
+
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.workMu.Init()
+ e.workMu.Lock()
+ e.tsOffset = timeStampOffset()
+ return e
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ result := waiter.EventMask(0)
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ switch e.state {
+ case stateInitial, stateBound, stateConnecting:
+ // Ready for nothing.
+
+ case stateClosed, stateError:
+ // Ready for anything.
+ result = mask
+
+ case stateListen:
+ // Check if there's anything in the accepted channel.
+ if (mask & waiter.EventIn) != 0 {
+ if len(e.acceptedChan) > 0 {
+ result |= waiter.EventIn
+ }
+ }
+
+ case stateConnected:
+ // Determine if the endpoint is writable if requested.
+ if (mask & waiter.EventOut) != 0 {
+ e.sndBufMu.Lock()
+ if e.sndClosed || e.sndBufUsed < e.sndBufSize {
+ result |= waiter.EventOut
+ }
+ e.sndBufMu.Unlock()
+ }
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvListMu.Lock()
+ if e.rcvBufUsed > 0 || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvListMu.Unlock()
+ }
+ }
+
+ return result
+}
+
+func (e *endpoint) fetchNotifications() uint32 {
+ return atomic.SwapUint32(&e.notifyFlags, 0)
+}
+
+func (e *endpoint) notifyProtocolGoroutine(n uint32) {
+ for {
+ v := atomic.LoadUint32(&e.notifyFlags)
+ if v&n == n {
+ // The flags are already set.
+ return
+ }
+
+ if atomic.CompareAndSwapUint32(&e.notifyFlags, v, v|n) {
+ if v == 0 {
+ // We are causing a transition from no flags to
+ // at least one flag set, so we must cause the
+ // protocol goroutine to wake up.
+ e.notificationWaker.Assert()
+ }
+ return
+ }
+ }
+}
+
+// Close puts the endpoint in a closed state and frees all resources associated
+// with it. It must be called only once and with no other concurrent calls to
+// the endpoint.
+func (e *endpoint) Close() {
+ // Issue a shutdown so that the peer knows we won't send any more data
+ // if we're connected, or stop accepting if we're listening.
+ e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+
+ e.mu.Lock()
+
+ // For listening sockets, we always release ports inline so that they
+ // are immediately available for reuse after Close() is called. If also
+ // registered, we unregister as well otherwise the next user would fail
+ // in Listen() when trying to register.
+ if e.state == stateListen && e.isPortReserved {
+ if e.isRegistered {
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.isRegistered = false
+ }
+
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.isPortReserved = false
+ }
+
+ // Either perform the local cleanup or kick the worker to make sure it
+ // knows it needs to cleanup.
+ tcpip.AddDanglingEndpoint(e)
+ if !e.workerRunning {
+ e.cleanupLocked()
+ } else {
+ e.workerCleanup = true
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+
+ e.mu.Unlock()
+}
+
+// cleanupLocked frees all resources associated with the endpoint. It is called
+// after Close() is called and the worker goroutine (if any) is done with its
+// work.
+func (e *endpoint) cleanupLocked() {
+ // Close all endpoints that might have been accepted by TCP but not by
+ // the client.
+ if e.acceptedChan != nil {
+ close(e.acceptedChan)
+ for n := range e.acceptedChan {
+ n.mu.Lock()
+ n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ n.mu.Unlock()
+ n.Close()
+ }
+ e.acceptedChan = nil
+ }
+ e.workerCleanup = false
+
+ if e.isRegistered {
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.isRegistered = false
+ }
+
+ if e.isPortReserved {
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.isPortReserved = false
+ }
+
+ e.route.Release()
+ tcpip.DeleteDanglingEndpoint(e)
+}
+
+// Read reads data from the endpoint.
+func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.mu.RLock()
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data. Also note that a RST being received
+ // would cause the state to become stateError so we should allow the
+ // reads to proceed before returning a ECONNRESET.
+ e.rcvListMu.Lock()
+ bufUsed := e.rcvBufUsed
+ if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 {
+ e.rcvListMu.Unlock()
+ he := e.hardError
+ e.mu.RUnlock()
+ if s == stateError {
+ return buffer.View{}, tcpip.ControlMessages{}, he
+ }
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ }
+
+ v, err := e.readLocked()
+ e.rcvListMu.Unlock()
+
+ e.mu.RUnlock()
+
+ return v, tcpip.ControlMessages{}, err
+}
+
+func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
+ if e.rcvBufUsed == 0 {
+ if e.rcvClosed || e.state != stateConnected {
+ return buffer.View{}, tcpip.ErrClosedForReceive
+ }
+ return buffer.View{}, tcpip.ErrWouldBlock
+ }
+
+ s := e.rcvList.Front()
+ views := s.data.Views()
+ v := views[s.viewToDeliver]
+ s.viewToDeliver++
+
+ if s.viewToDeliver >= len(views) {
+ e.rcvList.Remove(s)
+ s.decRef()
+ }
+
+ scale := e.rcv.rcvWndScale
+ wasZero := e.zeroReceiveWindow(scale)
+ e.rcvBufUsed -= len(v)
+ if wasZero && !e.zeroReceiveWindow(scale) {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
+
+ return v, nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // The endpoint cannot be written to if it's not connected.
+ if e.state != stateConnected {
+ switch e.state {
+ case stateError:
+ return 0, nil, e.hardError
+ default:
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+ }
+
+ // Nothing to do if the buffer is empty.
+ if p.Size() == 0 {
+ return 0, nil, nil
+ }
+
+ e.sndBufMu.Lock()
+
+ // Check if the connection has already been closed for sends.
+ if e.sndClosed {
+ e.sndBufMu.Unlock()
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Check against the limit.
+ avail := e.sndBufSize - e.sndBufUsed
+ if avail <= 0 {
+ e.sndBufMu.Unlock()
+ return 0, nil, tcpip.ErrWouldBlock
+ }
+
+ v, perr := p.Get(avail)
+ if perr != nil {
+ e.sndBufMu.Unlock()
+ return 0, nil, perr
+ }
+
+ l := len(v)
+ s := newSegmentFromView(&e.route, e.id, v)
+
+ // Add data to the send queue.
+ e.sndBufUsed += l
+ e.sndBufInQueue += seqnum.Size(l)
+ e.sndQueue.PushBack(s)
+
+ e.sndBufMu.Unlock()
+
+ if e.workMu.TryLock() {
+ // Do the work inline.
+ e.handleWrite()
+ e.workMu.Unlock()
+ } else {
+ // Let the protocol goroutine do the work.
+ e.sndWaker.Assert()
+ }
+ return uintptr(l), nil, nil
+}
+
+// Peek reads data without consuming it from the endpoint.
+//
+// This method does not block if there is no data pending.
+func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data.
+ if s := e.state; s != stateConnected && s != stateClosed {
+ if s == stateError {
+ return 0, tcpip.ControlMessages{}, e.hardError
+ }
+ return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ }
+
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
+ if e.rcvBufUsed == 0 {
+ if e.rcvClosed || e.state != stateConnected {
+ return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ }
+ return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
+ // Make a copy of vec so we can modify the slide headers.
+ vec = append([][]byte(nil), vec...)
+
+ var num uintptr
+
+ for s := e.rcvList.Front(); s != nil; s = s.Next() {
+ views := s.data.Views()
+
+ for i := s.viewToDeliver; i < len(views); i++ {
+ v := views[i]
+
+ for len(v) > 0 {
+ if len(vec) == 0 {
+ return num, tcpip.ControlMessages{}, nil
+ }
+ if len(vec[0]) == 0 {
+ vec = vec[1:]
+ continue
+ }
+
+ n := copy(vec[0], v)
+ v = v[n:]
+ vec[0] = vec[0][n:]
+ num += uintptr(n)
+ }
+ }
+ }
+
+ return num, tcpip.ControlMessages{}, nil
+}
+
+// zeroReceiveWindow checks if the receive window to be announced now would be
+// zero, based on the amount of available buffer and the receive window scaling.
+//
+// It must be called with rcvListMu held.
+func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
+ if e.rcvBufUsed >= e.rcvBufSize {
+ return true
+ }
+
+ return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.DelayOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.delay, 1)
+ }
+ return nil
+
+ case tcpip.CorkOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ return nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.reuseAddr = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.QuickAckOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.slowAck, 1)
+ } else {
+ atomic.StoreUint32(&e.slowAck, 0)
+ }
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs ReceiveBufferSizeOption
+ size := int(v)
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if size < rs.Min {
+ size = rs.Min
+ }
+ if size > rs.Max {
+ size = rs.Max
+ }
+ }
+
+ mask := uint32(notifyReceiveWindowChanged)
+
+ e.rcvListMu.Lock()
+
+ // Make sure the receive buffer size allows us to send a
+ // non-zero window size.
+ scale := uint8(0)
+ if e.rcv != nil {
+ scale = e.rcv.rcvWndScale
+ }
+ if size>>scale == 0 {
+ size = 1 << scale
+ }
+
+ // Make sure 2*size doesn't overflow.
+ if size > math.MaxInt32/2 {
+ size = math.MaxInt32 / 2
+ }
+
+ wasZero := e.zeroReceiveWindow(scale)
+ e.rcvBufSize = size
+ if wasZero && !e.zeroReceiveWindow(scale) {
+ mask |= notifyNonZeroReceiveWindow
+ }
+ e.rcvListMu.Unlock()
+
+ e.notifyProtocolGoroutine(mask)
+ return nil
+
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ size := int(v)
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if size < ss.Min {
+ size = ss.Min
+ }
+ if size > ss.Max {
+ size = ss.Max
+ }
+ }
+
+ e.sndBufMu.Lock()
+ e.sndBufSize = size
+ e.sndBufMu.Unlock()
+ return nil
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // We only allow this to be set when we're in the initial state.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v != 0
+ return nil
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ e.keepalive.enabled = v != 0
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ e.keepalive.idle = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ e.keepalive.interval = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ e.keepalive.count = int(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.BroadcastOption:
+ e.mu.Lock()
+ e.broadcast = v != 0
+ e.mu.Unlock()
+ return nil
+
+ default:
+ return nil
+ }
+}
+
+// readyReceiveSize returns the number of bytes ready to be received.
+func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // The endpoint cannot be in listen state.
+ if e.state == stateListen {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
+ return e.rcvBufUsed, nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ e.lastErrorMu.Lock()
+ err := e.lastError
+ e.lastError = nil
+ e.lastErrorMu.Unlock()
+ return err
+
+ case *tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.sndBufMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
+ e.rcvListMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ v, err := e.readyReceiveSize()
+ if err != nil {
+ return err
+ }
+
+ *o = tcpip.ReceiveQueueSizeOption(v)
+ return nil
+
+ case *tcpip.DelayOption:
+ *o = 0
+ if v := atomic.LoadUint32(&e.delay); v != 0 {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.CorkOption:
+ *o = 0
+ if v := atomic.LoadUint32(&e.cork); v != 0 {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.ReuseAddressOption:
+ e.mu.RLock()
+ v := e.reuseAddr
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.QuickAckOption:
+ *o = 1
+ if v := atomic.LoadUint32(&e.slowAck); v != 0 {
+ *o = 0
+ }
+ return nil
+
+ case *tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.TCPInfoOption:
+ *o = tcpip.TCPInfoOption{}
+ e.mu.RLock()
+ snd := e.snd
+ e.mu.RUnlock()
+ if snd != nil {
+ snd.rtt.Lock()
+ o.RTT = snd.rtt.srtt
+ o.RTTVar = snd.rtt.rttvar
+ snd.rtt.Unlock()
+ }
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ v := e.keepalive.enabled
+ e.keepalive.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIdleOption(e.keepalive.idle)
+ e.keepalive.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval)
+ e.keepalive.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveCountOption(e.keepalive.count)
+ e.keepalive.Unlock()
+ return nil
+
+ case *tcpip.OutOfBandInlineOption:
+ // We don't currently support disabling this option.
+ *o = 1
+ return nil
+
+ case *tcpip.BroadcastOption:
+ e.mu.Lock()
+ v := e.broadcast
+ e.mu.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ // Fail if using a v4 mapped address on a v6only endpoint.
+ if e.v6only {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == "\x00\x00\x00\x00" {
+ addr.Addr = ""
+ }
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
+
+// Connect connects the endpoint to its peer.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ return e.connect(addr, true, true)
+}
+
+// connect connects the endpoint to its peer. In the normal non-S/R case, the
+// new connection is expected to run the main goroutine and perform handshake.
+// In restore of previously connected endpoints, both ends will be passively
+// created (so no new handshaking is done); for stack-accepted connections not
+// yet accepted by the app, they are restored without running the main goroutine
+// here.
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ defer func() {
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ }
+ }()
+
+ connectingAddr := addr.Addr
+
+ netProto, err := e.checkV4Mapped(&addr)
+ if err != nil {
+ return err
+ }
+
+ nicid := addr.NIC
+ switch e.state {
+ case stateBound:
+ // If we're already bound to a NIC but the caller is requesting
+ // that we use a different one now, we cannot proceed.
+ if e.boundNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.boundNICID {
+ return tcpip.ErrNoRoute
+ }
+
+ nicid = e.boundNICID
+
+ case stateInitial:
+ // Nothing to do. We'll eventually fill-in the gaps in the ID
+ // (if any) when we find a route.
+
+ case stateConnecting:
+ // A connection request has already been issued but hasn't
+ // completed yet.
+ return tcpip.ErrAlreadyConnecting
+
+ case stateConnected:
+ // The endpoint is already connected. If caller hasn't been notified yet, return success.
+ if !e.isConnectNotified {
+ e.isConnectNotified = true
+ return nil
+ }
+ // Otherwise return that it's already connected.
+ return tcpip.ErrAlreadyConnected
+
+ case stateError:
+ return e.hardError
+
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ origID := e.id
+
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ e.id.LocalAddress = r.LocalAddress
+ e.id.RemoteAddress = r.RemoteAddress
+ e.id.RemotePort = addr.Port
+
+ if e.id.LocalPort != 0 {
+ // The endpoint is bound to a port, attempt to register it.
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ if err != nil {
+ return err
+ }
+ } else {
+ // The endpoint doesn't have a local port yet, so try to get
+ // one. Make sure that it isn't one that will result in the same
+ // address/port for both local and remote (otherwise this
+ // endpoint would be trying to connect to itself).
+ sameAddr := e.id.LocalAddress == e.id.RemoteAddress
+ if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ if sameAddr && p == e.id.RemotePort {
+ return false, nil
+ }
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ return false, nil
+ }
+
+ id := e.id
+ id.LocalPort = p
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ case nil:
+ e.id = id
+ return true, nil
+ case tcpip.ErrPortInUse:
+ return false, nil
+ default:
+ return false, err
+ }
+ }); err != nil {
+ return err
+ }
+ }
+
+ // Remove the port reservation. This can happen when Bind is called
+ // before Connect: in such a case we don't want to hold on to
+ // reservations anymore.
+ if e.isPortReserved {
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.isPortReserved = false
+ }
+
+ e.isRegistered = true
+ e.state = stateConnecting
+ e.route = r.Clone()
+ e.boundNICID = nicid
+ e.effectiveNetProtos = netProtos
+ e.connectingAddress = connectingAddr
+
+ e.initGSO()
+
+ // Connect in the restore phase does not perform handshake. Restore its
+ // connection setting here.
+ if !handshake {
+ e.segmentQueue.mu.Lock()
+ for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
+ for s := l.Front(); s != nil; s = s.Next() {
+ s.id = e.id
+ s.route = r.Clone()
+ e.sndWaker.Assert()
+ }
+ }
+ e.segmentQueue.mu.Unlock()
+ e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
+ e.state = stateConnected
+ }
+
+ if run {
+ e.workerRunning = true
+ e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+ go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save.
+ }
+
+ return tcpip.ErrConnectStarted
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection to its
+// peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.shutdownFlags |= flags
+
+ switch e.state {
+ case stateConnected:
+ // Close for read.
+ if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ // Mark read side as closed.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ rcvBufUsed := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ // If we're fully closed and we have unread data we need to abort
+ // the connection with a RST.
+ if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
+ e.notifyProtocolGoroutine(notifyReset)
+ return nil
+ }
+ }
+
+ // Close for write.
+ if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
+ e.sndBufMu.Lock()
+
+ if e.sndClosed {
+ // Already closed.
+ e.sndBufMu.Unlock()
+ break
+ }
+
+ // Queue fin segment.
+ s := newSegmentFromView(&e.route, e.id, nil)
+ e.sndQueue.PushBack(s)
+ e.sndBufInQueue++
+
+ // Mark endpoint as closed.
+ e.sndClosed = true
+
+ e.sndBufMu.Unlock()
+
+ // Tell protocol goroutine to close.
+ e.sndCloseWaker.Assert()
+ }
+
+ case stateListen:
+ // Tell protocolListenLoop to stop.
+ if flags&tcpip.ShutdownRead != 0 {
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+
+ default:
+ return tcpip.ErrNotConnected
+ }
+
+ return nil
+}
+
+// Listen puts the endpoint in "listen" mode, which allows it to accept
+// new connections.
+func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ defer func() {
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ }
+ }()
+
+ // Allow the backlog to be adjusted if the endpoint is not shutting down.
+ // When the endpoint shuts down, it sets workerCleanup to true, and from
+ // that point onward, acceptedChan is the responsibility of the cleanup()
+ // method (and should not be touched anywhere else, including here).
+ if e.state == stateListen && !e.workerCleanup {
+ // Adjust the size of the channel iff we can fix existing
+ // pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return tcpip.ErrInvalidEndpointState
+ }
+ if cap(e.acceptedChan) == backlog {
+ return nil
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *endpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
+ return nil
+ }
+
+ // Endpoint must be bound before it can transition to listen mode.
+ if e.state != stateBound {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Register the endpoint.
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ return err
+ }
+
+ e.isRegistered = true
+ e.state = stateListen
+ if e.acceptedChan == nil {
+ e.acceptedChan = make(chan *endpoint, backlog)
+ }
+ e.workerRunning = true
+
+ go e.protocolListenLoop( // S/R-SAFE: drained on save.
+ seqnum.Size(e.receiveBufferAvailable()))
+
+ return nil
+}
+
+// startAcceptedLoop sets up required state and starts a goroutine with the
+// main loop for accepted connections.
+func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+ e.waiterQueue = waiterQueue
+ e.workerRunning = true
+ go e.protocolMainLoop(false) // S/R-SAFE: drained on save.
+}
+
+// Accept returns a new endpoint if a peer has established a connection
+// to an endpoint previously set to listen mode.
+func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // Endpoint must be in listen state before it can accept connections.
+ if e.state != stateListen {
+ return nil, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Get the new accepted endpoint.
+ var n *endpoint
+ select {
+ case n = <-e.acceptedChan:
+ default:
+ return nil, nil, tcpip.ErrWouldBlock
+ }
+
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
+ return n, wq, nil
+}
+
+// Bind binds the endpoint to a specific local port and optionally address.
+func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore. This is because once the endpoint goes into a connected or
+ // listen state, it is already bound.
+ if e.state != stateInitial {
+ return tcpip.ErrAlreadyBound
+ }
+
+ e.bindAddress = addr.Addr
+ netProto, err := e.checkV4Mapped(&addr)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
+ }
+
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ if err != nil {
+ return err
+ }
+
+ e.isPortReserved = true
+ e.effectiveNetProtos = netProtos
+ e.id.LocalPort = port
+
+ // Any failures beyond this point must remove the port registration.
+ defer func() {
+ if err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.isPortReserved = false
+ e.effectiveNetProtos = nil
+ e.id.LocalPort = 0
+ e.id.LocalAddress = ""
+ e.boundNICID = 0
+ }
+ }()
+
+ // If an address is specified, we must ensure that it's one of our
+ // local addresses.
+ if len(addr.Addr) != 0 {
+ nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nic == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ e.boundNICID = nic
+ e.id.LocalAddress = addr.Addr
+ }
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ NIC: e.boundNICID,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ NIC: e.boundNICID,
+ }, nil
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ s := newSegment(r, id, vv)
+ if !s.parse() {
+ e.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ s.decRef()
+ return
+ }
+
+ if !s.csumValid {
+ e.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stack.Stats().TCP.ChecksumErrors.Increment()
+ s.decRef()
+ return
+ }
+
+ e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ if (s.flags & header.TCPFlagRst) != 0 {
+ e.stack.Stats().TCP.ResetsReceived.Increment()
+ }
+
+ // Send packet to worker goroutine.
+ if e.segmentQueue.enqueue(s) {
+ e.newSegmentWaker.Assert()
+ } else {
+ // The queue is full, so we drop the segment.
+ e.stack.Stats().DroppedPackets.Increment()
+ s.decRef()
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ switch typ {
+ case stack.ControlPacketTooBig:
+ e.sndBufMu.Lock()
+ e.packetTooBigCount++
+ if v := int(extra); v < e.sndMTU {
+ e.sndMTU = v
+ }
+ e.sndBufMu.Unlock()
+
+ e.notifyProtocolGoroutine(notifyMTUChanged)
+ }
+}
+
+// updateSndBufferUsage is called by the protocol goroutine when room opens up
+// in the send buffer. The number of newly available bytes is v.
+func (e *endpoint) updateSndBufferUsage(v int) {
+ e.sndBufMu.Lock()
+ notify := e.sndBufUsed >= e.sndBufSize>>1
+ e.sndBufUsed -= v
+ // We only notify when there is half the sndBufSize available after
+ // a full buffer event occurs. This ensures that we don't wake up
+ // writers to queue just 1-2 segments and go back to sleep.
+ notify = notify && e.sndBufUsed < e.sndBufSize>>1
+ e.sndBufMu.Unlock()
+
+ if notify {
+ e.waiterQueue.Notify(waiter.EventOut)
+ }
+}
+
+// readyToRead is called by the protocol goroutine when a new segment is ready
+// to be read, or when the connection is closed for receiving (in which case
+// s will be nil).
+func (e *endpoint) readyToRead(s *segment) {
+ e.rcvListMu.Lock()
+ if s != nil {
+ s.incRef()
+ e.rcvBufUsed += s.data.Size()
+ e.rcvList.PushBack(s)
+ } else {
+ e.rcvClosed = true
+ }
+ e.rcvListMu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventIn)
+}
+
+// receiveBufferAvailable calculates how many bytes are still available in the
+// receive buffer.
+func (e *endpoint) receiveBufferAvailable() int {
+ e.rcvListMu.Lock()
+ size := e.rcvBufSize
+ used := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ // We may use more bytes than the buffer size when the receive buffer
+ // shrinks.
+ if used >= size {
+ return 0
+ }
+
+ return size - used
+}
+
+func (e *endpoint) receiveBufferSize() int {
+ e.rcvListMu.Lock()
+ size := e.rcvBufSize
+ e.rcvListMu.Unlock()
+
+ return size
+}
+
+// updateRecentTimestamp updates the recent timestamp using the algorithm
+// described in https://tools.ietf.org/html/rfc7323#section-4.3
+func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
+ if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ e.recentTS = tsVal
+ }
+}
+
+// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
+// the SYN options indicate that timestamp option was negotiated. It also
+// initializes the recentTS with the value provided in synOpts.TSval.
+func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
+ if synOpts.TS {
+ e.sendTSOk = true
+ e.recentTS = synOpts.TSVal
+ }
+}
+
+// timestamp returns the timestamp value to be used in the TSVal field of the
+// timestamp option for outgoing TCP segments for a given endpoint.
+func (e *endpoint) timestamp() uint32 {
+ return tcpTimeStamp(e.tsOffset)
+}
+
+// tcpTimeStamp returns a timestamp offset by the provided offset. This is
+// not inlined above as it's used when SYN cookies are in use and endpoint
+// is not created at the time when the SYN cookie is sent.
+func tcpTimeStamp(offset uint32) uint32 {
+ now := time.Now()
+ return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset
+}
+
+// timeStampOffset returns a randomized timestamp offset to be used when sending
+// timestamp values in a timestamp option for a TCP segment.
+func timeStampOffset() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ // Initialize a random tsOffset that will be added to the recentTS
+ // everytime the timestamp is sent when the Timestamp option is enabled.
+ //
+ // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
+ // why this is required.
+ //
+ // NOTE: This is not completely to spec as normally this should be
+ // initialized in a manner analogous to how sequence numbers are
+ // randomized per connection basis. But for now this is sufficient.
+ return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
+// if the SYN options indicate that the SACK option was negotiated and the TCP
+// stack is configured to enable TCP SACK option.
+func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+ var v SACKEnabled
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+ // Stack doesn't support SACK. So just return.
+ return
+ }
+ if bool(v) && synOpts.SACKPermitted {
+ e.sackPermitted = true
+ }
+}
+
+// maxOptionSize return the maximum size of TCP options.
+func (e *endpoint) maxOptionSize() (size int) {
+ var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock
+ options := e.makeOptions(maxSackBlocks[:])
+ size = len(options)
+ putOptions(options)
+
+ return size
+}
+
+// completeState makes a full copy of the endpoint and returns it. This is used
+// before invoking the probe. The state returned may not be fully consistent if
+// there are intervening syscalls when the state is being copied.
+func (e *endpoint) completeState() stack.TCPEndpointState {
+ var s stack.TCPEndpointState
+ s.SegTime = time.Now()
+
+ // Copy EndpointID.
+ e.mu.Lock()
+ s.ID = stack.TCPEndpointID(e.id)
+ e.mu.Unlock()
+
+ // Copy endpoint rcv state.
+ e.rcvListMu.Lock()
+ s.RcvBufSize = e.rcvBufSize
+ s.RcvBufUsed = e.rcvBufUsed
+ s.RcvClosed = e.rcvClosed
+ e.rcvListMu.Unlock()
+
+ // Endpoint TCP Option state.
+ s.SendTSOk = e.sendTSOk
+ s.RecentTS = e.recentTS
+ s.TSOffset = e.tsOffset
+ s.SACKPermitted = e.sackPermitted
+ s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
+ copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks])
+ s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy()
+
+ // Copy endpoint send state.
+ e.sndBufMu.Lock()
+ s.SndBufSize = e.sndBufSize
+ s.SndBufUsed = e.sndBufUsed
+ s.SndClosed = e.sndClosed
+ s.SndBufInQueue = e.sndBufInQueue
+ s.PacketTooBigCount = e.packetTooBigCount
+ s.SndMTU = e.sndMTU
+ e.sndBufMu.Unlock()
+
+ // Copy receiver state.
+ s.Receiver = stack.TCPReceiverState{
+ RcvNxt: e.rcv.rcvNxt,
+ RcvAcc: e.rcv.rcvAcc,
+ RcvWndScale: e.rcv.rcvWndScale,
+ PendingBufUsed: e.rcv.pendingBufUsed,
+ PendingBufSize: e.rcv.pendingBufSize,
+ }
+
+ // Copy sender state.
+ s.Sender = stack.TCPSenderState{
+ LastSendTime: e.snd.lastSendTime,
+ DupAckCount: e.snd.dupAckCount,
+ FastRecovery: stack.TCPFastRecoveryState{
+ Active: e.snd.fr.active,
+ First: e.snd.fr.first,
+ Last: e.snd.fr.last,
+ MaxCwnd: e.snd.fr.maxCwnd,
+ HighRxt: e.snd.fr.highRxt,
+ RescueRxt: e.snd.fr.rescueRxt,
+ },
+ SndCwnd: e.snd.sndCwnd,
+ Ssthresh: e.snd.sndSsthresh,
+ SndCAAckCount: e.snd.sndCAAckCount,
+ Outstanding: e.snd.outstanding,
+ SndWnd: e.snd.sndWnd,
+ SndUna: e.snd.sndUna,
+ SndNxt: e.snd.sndNxt,
+ RTTMeasureSeqNum: e.snd.rttMeasureSeqNum,
+ RTTMeasureTime: e.snd.rttMeasureTime,
+ Closed: e.snd.closed,
+ RTO: e.snd.rto,
+ SRTTInited: e.snd.srttInited,
+ MaxPayloadSize: e.snd.maxPayloadSize,
+ SndWndScale: e.snd.sndWndScale,
+ MaxSentAck: e.snd.maxSentAck,
+ }
+ e.snd.rtt.Lock()
+ s.Sender.SRTT = e.snd.rtt.srtt
+ e.snd.rtt.Unlock()
+
+ if cubic, ok := e.snd.cc.(*cubicState); ok {
+ s.Sender.Cubic = stack.TCPCubicState{
+ WMax: cubic.wMax,
+ WLastMax: cubic.wLastMax,
+ T: cubic.t,
+ TimeSinceLastCongestion: time.Since(cubic.t),
+ C: cubic.c,
+ K: cubic.k,
+ Beta: cubic.beta,
+ WC: cubic.wC,
+ WEst: cubic.wEst,
+ }
+ }
+ return s
+}
+
+func (e *endpoint) initGSO() {
+ if e.route.Capabilities()&stack.CapabilityGSO == 0 {
+ return
+ }
+
+ gso := &stack.GSO{}
+ switch e.route.NetProto {
+ case header.IPv4ProtocolNumber:
+ gso.Type = stack.GSOTCPv4
+ gso.L3HdrLen = header.IPv4MinimumSize
+ case header.IPv6ProtocolNumber:
+ gso.Type = stack.GSOTCPv6
+ gso.L3HdrLen = header.IPv6MinimumSize
+ default:
+ panic(fmt.Sprintf("Unknown netProto: %v", e.netProto))
+ }
+ gso.NeedsCsum = true
+ gso.CsumOffset = header.TCPChecksumOffset
+ gso.MaxSize = e.route.GSOMaxSize()
+ e.gso = gso
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
new file mode 100644
index 000000000..e8aed2875
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -0,0 +1,362 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+func (e *endpoint) drainSegmentLocked() {
+ // Drain only up to once.
+ if e.drainDone != nil {
+ return
+ }
+
+ e.drainDone = make(chan struct{})
+ e.undrain = make(chan struct{})
+ e.mu.Unlock()
+
+ e.notifyProtocolGoroutine(notifyDrain)
+ <-e.drainDone
+
+ e.mu.Lock()
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets.
+ e.segmentQueue.setLimit(0)
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch e.state {
+ case stateInitial, stateBound:
+ case stateConnected:
+ if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
+ if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
+ panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
+ }
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.mu.Unlock()
+ e.Close()
+ e.mu.Lock()
+ }
+ if !e.workerRunning {
+ // The endpoint must be in acceptedChan or has been just
+ // disconnected and closed.
+ break
+ }
+ fallthrough
+ case stateListen, stateConnecting:
+ e.drainSegmentLocked()
+ if e.state != stateClosed && e.state != stateError {
+ if !e.workerRunning {
+ panic("endpoint has no worker running in listen, connecting, or connected state")
+ }
+ break
+ }
+ fallthrough
+ case stateError, stateClosed:
+ for e.state == stateError && e.workerRunning {
+ e.mu.Unlock()
+ time.Sleep(100 * time.Millisecond)
+ e.mu.Lock()
+ }
+ if e.workerRunning {
+ panic("endpoint still has worker running in closed or error state")
+ }
+ default:
+ panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
+ }
+
+ if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
+ panic("endpoint still has waiters upon save")
+ }
+
+ if e.state != stateClosed && !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) {
+ panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
+ }
+}
+
+// saveAcceptedChan is invoked by stateify.
+func (e *endpoint) saveAcceptedChan() []*endpoint {
+ if e.acceptedChan == nil {
+ return nil
+ }
+ acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan))
+ for i := 0; i < len(acceptedEndpoints); i++ {
+ select {
+ case ep := <-e.acceptedChan:
+ acceptedEndpoints[i] = ep
+ default:
+ panic("endpoint acceptedChan buffer got consumed by background context")
+ }
+ }
+ for i := 0; i < len(acceptedEndpoints); i++ {
+ select {
+ case e.acceptedChan <- acceptedEndpoints[i]:
+ default:
+ panic("endpoint acceptedChan buffer got populated by background context")
+ }
+ }
+ return acceptedEndpoints
+}
+
+// loadAcceptedChan is invoked by stateify.
+func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
+ if cap(acceptedEndpoints) > 0 {
+ e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints))
+ for _, ep := range acceptedEndpoints {
+ e.acceptedChan <- ep
+ }
+ }
+}
+
+// saveState is invoked by stateify.
+func (e *endpoint) saveState() endpointState {
+ return e.state
+}
+
+// Endpoint loading must be done in the following ordering by their state, to
+// avoid dangling connecting w/o listening peer, and to avoid conflicts in port
+// reservation.
+var connectedLoading sync.WaitGroup
+var listenLoading sync.WaitGroup
+var connectingLoading sync.WaitGroup
+
+// Bound endpoint loading happens last.
+
+// loadState is invoked by stateify.
+func (e *endpoint) loadState(state endpointState) {
+ // This is to ensure that the loading wait groups include all applicable
+ // endpoints before any asynchronous calls to the Wait() methods.
+ switch state {
+ case stateConnected:
+ connectedLoading.Add(1)
+ case stateListen:
+ listenLoading.Add(1)
+ case stateConnecting:
+ connectingLoading.Add(1)
+ }
+ e.state = state
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.workMu.Init()
+
+ state := e.state
+ switch state {
+ case stateInitial, stateBound, stateListen, stateConnecting, stateConnected:
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
+ }
+ if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
+ }
+ }
+ }
+
+ bind := func() {
+ e.state = stateInitial
+ if len(e.bindAddress) == 0 {
+ e.bindAddress = e.id.LocalAddress
+ }
+ if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
+ panic("endpoint binding failed: " + err.String())
+ }
+ }
+
+ switch state {
+ case stateConnected:
+ bind()
+ if len(e.connectingAddress) == 0 {
+ // This endpoint is accepted by netstack but not yet by
+ // the app. If the endpoint is IPv6 but the remote
+ // address is IPv4, we need to connect as IPv6 so that
+ // dual-stack mode can be properly activated.
+ if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
+ } else {
+ e.connectingAddress = e.id.RemoteAddress
+ }
+ }
+ // Reset the scoreboard to reinitialize the sack information as
+ // we do not restore SACK information.
+ e.scoreboard.Reset()
+ if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectedLoading.Done()
+ case stateListen:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ bind()
+ backlog := cap(e.acceptedChan)
+ if err := e.Listen(backlog); err != nil {
+ panic("endpoint listening failed: " + err.String())
+ }
+ listenLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateConnecting:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ bind()
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectingLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateBound:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateClosed:
+ if e.isPortReserved {
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ e.state = stateClosed
+ tcpip.AsyncLoading.Done()
+ }()
+ }
+ fallthrough
+ case stateError:
+ tcpip.DeleteDanglingEndpoint(e)
+ }
+}
+
+// saveLastError is invoked by stateify.
+func (e *endpoint) saveLastError() string {
+ if e.lastError == nil {
+ return ""
+ }
+
+ return e.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (e *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.lastError = loadError(s)
+}
+
+// saveHardError is invoked by stateify.
+func (e *endpoint) saveHardError() string {
+ if e.hardError == nil {
+ return ""
+ }
+
+ return e.hardError.String()
+}
+
+// loadHardError is invoked by stateify.
+func (e *endpoint) loadHardError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.hardError = loadError(s)
+}
+
+var messageToError map[string]*tcpip.Error
+
+var populate sync.Once
+
+func loadError(s string) *tcpip.Error {
+ populate.Do(func() {
+ var errors = []*tcpip.Error{
+ tcpip.ErrUnknownProtocol,
+ tcpip.ErrUnknownNICID,
+ tcpip.ErrUnknownDevice,
+ tcpip.ErrUnknownProtocolOption,
+ tcpip.ErrDuplicateNICID,
+ tcpip.ErrDuplicateAddress,
+ tcpip.ErrNoRoute,
+ tcpip.ErrBadLinkEndpoint,
+ tcpip.ErrAlreadyBound,
+ tcpip.ErrInvalidEndpointState,
+ tcpip.ErrAlreadyConnecting,
+ tcpip.ErrAlreadyConnected,
+ tcpip.ErrNoPortAvailable,
+ tcpip.ErrPortInUse,
+ tcpip.ErrBadLocalAddress,
+ tcpip.ErrClosedForSend,
+ tcpip.ErrClosedForReceive,
+ tcpip.ErrWouldBlock,
+ tcpip.ErrConnectionRefused,
+ tcpip.ErrTimeout,
+ tcpip.ErrAborted,
+ tcpip.ErrConnectStarted,
+ tcpip.ErrDestinationRequired,
+ tcpip.ErrNotSupported,
+ tcpip.ErrQueueSizeNotSupported,
+ tcpip.ErrNotConnected,
+ tcpip.ErrConnectionReset,
+ tcpip.ErrConnectionAborted,
+ tcpip.ErrNoSuchFile,
+ tcpip.ErrInvalidOptionValue,
+ tcpip.ErrNoLinkAddress,
+ tcpip.ErrBadAddress,
+ tcpip.ErrNetworkUnreachable,
+ tcpip.ErrMessageTooLong,
+ tcpip.ErrNoBufferSpace,
+ tcpip.ErrBroadcastDisabled,
+ tcpip.ErrNotPermitted,
+ }
+
+ messageToError = make(map[string]*tcpip.Error)
+ for _, e := range errors {
+ if messageToError[e.String()] != nil {
+ panic("tcpip errors with duplicated message: " + e.String())
+ }
+ messageToError[e.String()] = e
+ }
+ })
+
+ e, ok := messageToError[s]
+ if !ok {
+ panic("unknown error message: " + s)
+ }
+
+ return e
+}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
new file mode 100644
index 000000000..c30b45c2c
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -0,0 +1,171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Forwarder is a connection request forwarder, which allows clients to decide
+// what to do with a connection request, for example: ignore it, send a RST, or
+// attempt to complete the 3-way handshake.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+ maxInFlight int
+ handler func(*ForwarderRequest)
+
+ mu sync.Mutex
+ inFlight map[stack.TransportEndpointID]struct{}
+ listen *listenContext
+}
+
+// NewForwarder allocates and initializes a new forwarder with the given
+// maximum number of in-flight connection attempts. Once the maximum is reached
+// new incoming connection requests will be ignored.
+//
+// If rcvWnd is set to zero, the default buffer size is used instead.
+func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
+ if rcvWnd == 0 {
+ rcvWnd = DefaultBufferSize
+ }
+ return &Forwarder{
+ maxInFlight: maxInFlight,
+ handler: handler,
+ inFlight: make(map[stack.TransportEndpointID]struct{}),
+ listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
+ }
+}
+
+// HandlePacket handles a packet if it is of interest to the forwarder (i.e., if
+// it's a SYN packet), returning true if it's the case. Otherwise the packet
+// is not handled and false is returned.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ s := newSegment(r, id, vv)
+ defer s.decRef()
+
+ // We only care about well-formed SYN packets.
+ if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn {
+ return false
+ }
+
+ opts := parseSynSegmentOptions(s)
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // We have an inflight request for this id, ignore this one for now.
+ if _, ok := f.inFlight[id]; ok {
+ return true
+ }
+
+ // Ignore the segment if we're beyond the limit.
+ if len(f.inFlight) >= f.maxInFlight {
+ return true
+ }
+
+ // Launch a new goroutine to handle the request.
+ f.inFlight[id] = struct{}{}
+ s.incRef()
+ go f.handler(&ForwarderRequest{ // S/R-SAFE: not used by Sentry.
+ forwarder: f,
+ segment: s,
+ synOptions: opts,
+ })
+
+ return true
+}
+
+// ForwarderRequest represents a connection request received by the forwarder
+// and passed to the client. Clients must eventually call Complete() on it, and
+// may optionally create an endpoint to represent it via CreateEndpoint.
+type ForwarderRequest struct {
+ mu sync.Mutex
+ forwarder *Forwarder
+ segment *segment
+ synOptions header.TCPSynOptions
+}
+
+// ID returns the 4-tuple (src address, src port, dst address, dst port) that
+// represents the connection request.
+func (r *ForwarderRequest) ID() stack.TransportEndpointID {
+ return r.segment.id
+}
+
+// Complete completes the request, and optionally sends a RST segment back to the
+// sender.
+func (r *ForwarderRequest) Complete(sendReset bool) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.segment == nil {
+ panic("Completing already completed forwarder request")
+ }
+
+ // Remove request from the forwarder.
+ r.forwarder.mu.Lock()
+ delete(r.forwarder.inFlight, r.segment.id)
+ r.forwarder.mu.Unlock()
+
+ // If the caller requested, send a reset.
+ if sendReset {
+ replyWithReset(r.segment)
+ }
+
+ // Release all resources.
+ r.segment.decRef()
+ r.segment = nil
+ r.forwarder = nil
+}
+
+// CreateEndpoint creates a TCP endpoint for the connection request, performing
+// the 3-way handshake in the process.
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.segment == nil {
+ return nil, tcpip.ErrInvalidEndpointState
+ }
+
+ f := r.forwarder
+ ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
+ MSS: r.synOptions.MSS,
+ WS: r.synOptions.WS,
+ TS: r.synOptions.TS,
+ TSVal: r.synOptions.TSVal,
+ TSEcr: r.synOptions.TSEcr,
+ SACKPermitted: r.synOptions.SACKPermitted,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Start the protocol goroutine.
+ ep.startAcceptedLoop(queue)
+
+ return ep, nil
+}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
new file mode 100644
index 000000000..b31bcccfa
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -0,0 +1,250 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcp contains the implementation of the TCP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing tcp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package tcp
+
+import (
+ "strings"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolName is the string representation of the tcp protocol name.
+ ProtocolName = "tcp"
+
+ // ProtocolNumber is the tcp protocol number.
+ ProtocolNumber = header.TCPProtocolNumber
+
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ minBufferSize = 4 << 10 // 4096 bytes.
+
+ // DefaultBufferSize is the default size of the receive and send buffers.
+ DefaultBufferSize = 1 << 20 // 1MB
+
+ // MaxBufferSize is the largest size a receive and send buffer can grow to.
+ maxBufferSize = 4 << 20 // 4MB
+
+ // MaxUnprocessedSegments is the maximum number of unprocessed segments
+ // that can be queued for a given endpoint.
+ MaxUnprocessedSegments = 300
+)
+
+// SACKEnabled option can be used to enable SACK support in the TCP
+// protocol. See: https://tools.ietf.org/html/rfc2018.
+type SACKEnabled bool
+
+// SendBufferSizeOption allows the default, min and max send buffer sizes for
+// TCP endpoints to be queried or configured.
+type SendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// ReceiveBufferSizeOption allows the default, min and max receive buffer size
+// for TCP endpoints to be queried or configured.
+type ReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+const (
+ ccReno = "reno"
+ ccCubic = "cubic"
+)
+
+// CongestionControlOption sets the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption returns the supported congestion control
+// algorithms.
+type AvailableCongestionControlOption string
+
+type protocol struct {
+ mu sync.Mutex
+ sackEnabled bool
+ sendBufferSize SendBufferSizeOption
+ recvBufferSize ReceiveBufferSizeOption
+ congestionControl string
+ availableCongestionControl []string
+ allowedCongestionControl []string
+}
+
+// Number returns the tcp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+ return ProtocolNumber
+}
+
+// NewEndpoint creates a new tcp endpoint.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid tcp packet size.
+func (*protocol) MinimumPacketSize() int {
+ return header.TCPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given tcp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ h := header.TCP(v)
+ return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+//
+// RFC 793, page 36, states that "If the connection does not exist (CLOSED) then
+// a reset is sent in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected by this
+// means."
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+ s := newSegment(r, id, vv)
+ defer s.decRef()
+
+ if !s.parse() || !s.csumValid {
+ return false
+ }
+
+ // There's nothing to do if this is already a reset packet.
+ if s.flagIsSet(header.TCPFlagRst) {
+ return true
+ }
+
+ replyWithReset(s)
+ return true
+}
+
+// replyWithReset replies to the given segment with a reset segment.
+func replyWithReset(s *segment) {
+ // Get the seqnum from the packet if the ack flag is set.
+ seq := seqnum.Value(0)
+ if s.flagIsSet(header.TCPFlagAck) {
+ seq = s.ackNumber
+ }
+
+ ack := s.sequenceNumber.Add(s.logicalLen())
+
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */)
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case SACKEnabled:
+ p.mu.Lock()
+ p.sackEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
+ case SendBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.sendBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.recvBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case CongestionControlOption:
+ for _, c := range p.availableCongestionControl {
+ if string(v) == c {
+ p.mu.Lock()
+ p.congestionControl = string(v)
+ p.mu.Unlock()
+ return nil
+ }
+ }
+ return tcpip.ErrInvalidOptionValue
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SACKEnabled:
+ p.mu.Lock()
+ *v = SACKEnabled(p.sackEnabled)
+ p.mu.Unlock()
+ return nil
+
+ case *SendBufferSizeOption:
+ p.mu.Lock()
+ *v = p.sendBufferSize
+ p.mu.Unlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ p.mu.Lock()
+ *v = p.recvBufferSize
+ p.mu.Unlock()
+ return nil
+ case *CongestionControlOption:
+ p.mu.Lock()
+ *v = CongestionControlOption(p.congestionControl)
+ p.mu.Unlock()
+ return nil
+ case *AvailableCongestionControlOption:
+ p.mu.Lock()
+ *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ p.mu.Unlock()
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
+ })
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
new file mode 100644
index 000000000..b08a0e356
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -0,0 +1,221 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "container/heap"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+// receiver holds the state necessary to receive TCP segments and turn them
+// into a stream of bytes.
+//
+// +stateify savable
+type receiver struct {
+ ep *endpoint
+
+ rcvNxt seqnum.Value
+
+ // rcvAcc is one beyond the last acceptable sequence number. That is,
+ // the "largest" sequence value that the receiver has announced to the
+ // its peer that it's willing to accept. This may be different than
+ // rcvNxt + rcvWnd if the receive window is reduced; in that case we
+ // have to reduce the window as we receive more data instead of
+ // shrinking it.
+ rcvAcc seqnum.Value
+
+ rcvWndScale uint8
+
+ closed bool
+
+ pendingRcvdSegments segmentHeap
+ pendingBufUsed seqnum.Size
+ pendingBufSize seqnum.Size
+}
+
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
+ return &receiver{
+ ep: ep,
+ rcvNxt: irs + 1,
+ rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWndScale: rcvWndScale,
+ pendingBufSize: rcvWnd,
+ }
+}
+
+// acceptable checks if the segment sequence number range is acceptable
+// according to the table on page 26 of RFC 793.
+func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
+ rcvWnd := r.rcvNxt.Size(r.rcvAcc)
+ if rcvWnd == 0 {
+ return segLen == 0 && segSeq == r.rcvNxt
+ }
+
+ return segSeq.InWindow(r.rcvNxt, rcvWnd) ||
+ seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen)
+}
+
+// getSendParams returns the parameters needed by the sender when building
+// segments to send.
+func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
+ // Calculate the window size based on the current buffer size.
+ n := r.ep.receiveBufferAvailable()
+ acc := r.rcvNxt.Add(seqnum.Size(n))
+ if r.rcvAcc.LessThan(acc) {
+ r.rcvAcc = acc
+ }
+
+ return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale
+}
+
+// nonZeroWindow is called when the receive window grows from zero to nonzero;
+// in such cases we may need to send an ack to indicate to our peer that it can
+// resume sending data.
+func (r *receiver) nonZeroWindow() {
+ if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 {
+ // We never got around to announcing a zero window size, so we
+ // don't need to immediately announce a nonzero one.
+ return
+ }
+
+ // Immediately send an ack.
+ r.ep.snd.sendAck()
+}
+
+// consumeSegment attempts to consume a segment that was received by r. The
+// segment may have just been received or may have been received earlier but
+// wasn't ready to be consumed then.
+//
+// Returns true if the segment was consumed, false if it cannot be consumed
+// yet because of a missing segment.
+func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool {
+ if segLen > 0 {
+ // If the segment doesn't include the seqnum we're expecting to
+ // consume now, we're missing a segment. We cannot proceed until
+ // we receive that segment though.
+ if !r.rcvNxt.InWindow(segSeq, segLen) {
+ return false
+ }
+
+ // Trim segment to eliminate already acknowledged data.
+ if segSeq.LessThan(r.rcvNxt) {
+ diff := segSeq.Size(r.rcvNxt)
+ segLen -= diff
+ segSeq.UpdateForward(diff)
+ s.sequenceNumber.UpdateForward(diff)
+ s.data.TrimFront(int(diff))
+ }
+
+ // Move segment to ready-to-deliver list. Wakeup any waiters.
+ r.ep.readyToRead(s)
+
+ } else if segSeq != r.rcvNxt {
+ return false
+ }
+
+ // Update the segment that we're expecting to consume.
+ r.rcvNxt = segSeq.Add(segLen)
+
+ // Trim SACK Blocks to remove any SACK information that covers
+ // sequence numbers that have been consumed.
+ TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
+
+ if s.flagIsSet(header.TCPFlagFin) {
+ r.rcvNxt++
+
+ // Send ACK immediately.
+ r.ep.snd.sendAck()
+
+ // Tell any readers that no more data will come.
+ r.closed = true
+ r.ep.readyToRead(nil)
+
+ // Flush out any pending segments, except the very first one if
+ // it happens to be the one we're handling now because the
+ // caller is using it.
+ first := 0
+ if len(r.pendingRcvdSegments) != 0 && r.pendingRcvdSegments[0] == s {
+ first = 1
+ }
+
+ for i := first; i < len(r.pendingRcvdSegments); i++ {
+ r.pendingRcvdSegments[i].decRef()
+ }
+ r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
+ }
+
+ return true
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) {
+ // We don't care about receive processing anymore if the receive side
+ // is closed.
+ if r.closed {
+ return
+ }
+
+ segLen := seqnum.Size(s.data.Size())
+ segSeq := s.sequenceNumber
+
+ // If the sequence number range is outside the acceptable range, just
+ // send an ACK. This is according to RFC 793, page 37.
+ if !r.acceptable(segSeq, segLen) {
+ r.ep.snd.sendAck()
+ return
+ }
+
+ // Defer segment processing if it can't be consumed now.
+ if !r.consumeSegment(s, segSeq, segLen) {
+ if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
+ // We only store the segment if it's within our buffer
+ // size limit.
+ if r.pendingBufUsed < r.pendingBufSize {
+ r.pendingBufUsed += s.logicalLen()
+ s.incRef()
+ heap.Push(&r.pendingRcvdSegments, s)
+ }
+
+ UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
+
+ // Immediately send an ack so that the peer knows it may
+ // have to retransmit.
+ r.ep.snd.sendAck()
+ }
+ return
+ }
+
+ // By consuming the current segment, we may have filled a gap in the
+ // sequence number domain that allows pending segments to be consumed
+ // now. So try to do it.
+ for !r.closed && r.pendingRcvdSegments.Len() > 0 {
+ s := r.pendingRcvdSegments[0]
+ segLen := seqnum.Size(s.data.Size())
+ segSeq := s.sequenceNumber
+
+ // Skip segment altogether if it has already been acknowledged.
+ if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) &&
+ !r.consumeSegment(s, segSeq, segLen) {
+ break
+ }
+
+ heap.Pop(&r.pendingRcvdSegments)
+ r.pendingBufUsed -= s.logicalLen()
+ s.decRef()
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go
new file mode 100644
index 000000000..f83ebc717
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/reno.go
@@ -0,0 +1,103 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+// renoState stores the variables related to TCP New Reno congestion
+// control algorithm.
+//
+// +stateify savable
+type renoState struct {
+ s *sender
+}
+
+// newRenoCC initializes the state for the NewReno congestion control algorithm.
+func newRenoCC(s *sender) *renoState {
+ return &renoState{s: s}
+}
+
+// updateSlowStart will update the congestion window as per the slow-start
+// algorithm used by NewReno. If after adjusting the congestion window
+// we cross the SSthreshold then it will return the number of packets that
+// must be consumed in congestion avoidance mode.
+func (r *renoState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := r.s.sndCwnd + packetsAcked
+ if newcwnd >= r.s.sndSsthresh {
+ newcwnd = r.s.sndSsthresh
+ r.s.sndCAAckCount = 0
+ }
+
+ packetsAcked -= newcwnd - r.s.sndCwnd
+ r.s.sndCwnd = newcwnd
+ return packetsAcked
+}
+
+// updateCongestionAvoidance will update congestion window in congestion
+// avoidance mode as described in RFC5681 section 3.1
+func (r *renoState) updateCongestionAvoidance(packetsAcked int) {
+ // Consume the packets in congestion avoidance mode.
+ r.s.sndCAAckCount += packetsAcked
+ if r.s.sndCAAckCount >= r.s.sndCwnd {
+ r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd
+ r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd
+ }
+}
+
+// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
+// page 6, eq. 4. It is called when we detect congestion in the network.
+func (r *renoState) reduceSlowStartThreshold() {
+ r.s.sndSsthresh = r.s.outstanding / 2
+ if r.s.sndSsthresh < 2 {
+ r.s.sndSsthresh = 2
+ }
+
+}
+
+// Update updates the congestion state based on the number of packets that
+// were acknowledged.
+// Update implements congestionControl.Update.
+func (r *renoState) Update(packetsAcked int) {
+ if r.s.sndCwnd < r.s.sndSsthresh {
+ packetsAcked = r.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ }
+ r.updateCongestionAvoidance(packetsAcked)
+}
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (r *renoState) HandleNDupAcks() {
+ // A retransmit was triggered due to nDupAckThreshold
+ // being hit. Reduce our slow start threshold.
+ r.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionControl.HandleRTOExpired.
+func (r *renoState) HandleRTOExpired() {
+ // We lost a packet, so reduce ssthresh.
+ r.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ r.s.sndCwnd = 1
+}
+
+// PostRecovery implements congestionControl.PostRecovery.
+func (r *renoState) PostRecovery() {
+ // noop.
+}
diff --git a/pkg/tcpip/transport/tcp/sack.go b/pkg/tcpip/transport/tcp/sack.go
new file mode 100644
index 000000000..6a013d99b
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack.go
@@ -0,0 +1,99 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // MaxSACKBlocks is the maximum number of SACK blocks stored
+ // at receiver side.
+ MaxSACKBlocks = 6
+)
+
+// UpdateSACKBlocks updates the list of SACK blocks to include the segment
+// specified by segStart->segEnd. If the segment happens to be an out of order
+// delivery then the first block in the sack.blocks always includes the
+// segment identified by segStart->segEnd.
+func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) {
+ newSB := header.SACKBlock{Start: segStart, End: segEnd}
+ if sack.NumBlocks == 0 {
+ sack.Blocks[0] = newSB
+ sack.NumBlocks = 1
+ return
+ }
+ var n = 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ start, end := sack.Blocks[i].Start, sack.Blocks[i].End
+ if end.LessThanEq(start) || start.LessThanEq(rcvNxt) {
+ // Discard any invalid blocks where end is before start
+ // and discard any sack blocks that are before rcvNxt as
+ // those have already been acked.
+ continue
+ }
+ if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) {
+ // Merge this SACK block into newSB and discard this SACK
+ // block.
+ if start.LessThan(newSB.Start) {
+ newSB.Start = start
+ }
+ if newSB.End.LessThan(end) {
+ newSB.End = end
+ }
+ } else {
+ // Save this block.
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ }
+ if rcvNxt.LessThan(newSB.Start) {
+ // If this was an out of order segment then make sure that the
+ // first SACK block is the one that includes the segment.
+ //
+ // See the first bullet point in
+ // https://tools.ietf.org/html/rfc2018#section-4
+ if n == MaxSACKBlocks {
+ // If the number of SACK blocks is equal to
+ // MaxSACKBlocks then discard the last SACK block.
+ n--
+ }
+ for i := n - 1; i >= 0; i-- {
+ sack.Blocks[i+1] = sack.Blocks[i]
+ }
+ sack.Blocks[0] = newSB
+ n++
+ }
+ sack.NumBlocks = n
+}
+
+// TrimSACKBlockList updates the sack block list by removing/modifying any block
+// where start is < rcvNxt.
+func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) {
+ n := 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ if sack.Blocks[i].End.LessThanEq(rcvNxt) {
+ continue
+ }
+ if sack.Blocks[i].Start.LessThan(rcvNxt) {
+ // Shrink this SACK block.
+ sack.Blocks[i].Start = rcvNxt
+ }
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ sack.NumBlocks = n
+}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go
new file mode 100644
index 000000000..1c5766a42
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go
@@ -0,0 +1,306 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/google/btree"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // maxSACKBlocks is the maximum number of distinct SACKBlocks the
+ // scoreboard will track. Once there are 100 distinct blocks, new
+ // insertions will fail.
+ maxSACKBlocks = 100
+
+ // defaultBtreeDegree is set to 2 as btree.New(2) results in a 2-3-4
+ // tree.
+ defaultBtreeDegree = 2
+)
+
+// SACKScoreboard stores a set of disjoint SACK ranges.
+//
+// +stateify savable
+type SACKScoreboard struct {
+ // smss is defined in RFC5681 as following:
+ //
+ // The SMSS is the size of the largest segment that the sender can
+ // transmit. This value can be based on the maximum transmission unit
+ // of the network, the path MTU discovery [RFC1191, RFC4821] algorithm,
+ // RMSS (see next item), or other factors. The size does not include
+ // the TCP/IP headers and options.
+ smss uint16
+ maxSACKED seqnum.Value
+ sacked seqnum.Size `state:"nosave"`
+ ranges *btree.BTree `state:"nosave"`
+}
+
+// NewSACKScoreboard returns a new SACK Scoreboard.
+func NewSACKScoreboard(smss uint16, iss seqnum.Value) *SACKScoreboard {
+ return &SACKScoreboard{
+ smss: smss,
+ ranges: btree.New(defaultBtreeDegree),
+ maxSACKED: iss,
+ }
+}
+
+// Reset erases all known range information from the SACK scoreboard.
+func (s *SACKScoreboard) Reset() {
+ s.ranges = btree.New(defaultBtreeDegree)
+ s.sacked = 0
+}
+
+// Insert inserts/merges the provided SACKBlock into the scoreboard.
+func (s *SACKScoreboard) Insert(r header.SACKBlock) {
+ if s.ranges.Len() >= maxSACKBlocks {
+ return
+ }
+
+ // Check if we can merge the new range with a range before or after it.
+ var toDelete []btree.Item
+ if s.maxSACKED.LessThan(r.End - 1) {
+ s.maxSACKED = r.End - 1
+ }
+ s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sacked := i.(header.SACKBlock)
+ // There is a hole between these two SACK blocks, so we can't
+ // merge anymore.
+ if r.End.LessThan(sacked.Start) {
+ return false
+ }
+ // There is some overlap at this point, merge the blocks and
+ // delete the other one.
+ //
+ // ----sS--------sE
+ // r.S---------------rE
+ // -------sE
+ if sacked.End.LessThan(r.End) {
+ // sacked is contained in the newly inserted range.
+ // Delete this block.
+ toDelete = append(toDelete, i)
+ return true
+ }
+ // sacked covers a range past end of the newly inserted
+ // block.
+ r.End = sacked.End
+ toDelete = append(toDelete, i)
+ return true
+ })
+
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sacked := i.(header.SACKBlock)
+ // sA------sE
+ // rA----rE
+ if sacked.End.LessThan(r.Start) {
+ return false
+ }
+ // The previous range extends into the current block. Merge it
+ // into the newly inserted range and delete the other one.
+ //
+ // <-rA---rE----<---rE--->
+ // sA--------------sE
+ r.Start = sacked.Start
+ // Extend r to cover sacked if sacked extends past r.
+ if r.End.LessThan(sacked.End) {
+ r.End = sacked.End
+ }
+ toDelete = append(toDelete, i)
+ return true
+ })
+ for _, i := range toDelete {
+ if sb := s.ranges.Delete(i); sb != nil {
+ sb := i.(header.SACKBlock)
+ s.sacked -= sb.Start.Size(sb.End)
+ }
+ }
+
+ replaced := s.ranges.ReplaceOrInsert(r)
+ if replaced == nil {
+ s.sacked += r.Start.Size(r.End)
+ }
+}
+
+// IsSACKED returns true if the a given range of sequence numbers denoted by r
+// are already covered by SACK information in the scoreboard.
+func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool {
+ if s.Empty() {
+ return false
+ }
+
+ found := false
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.End.LessThan(r.Start) {
+ return false
+ }
+ if sacked.Contains(r) {
+ found = true
+ return false
+ }
+ return true
+ })
+ return found
+}
+
+// Dump prints the state of the scoreboard structure.
+func (s *SACKScoreboard) String() string {
+ var str strings.Builder
+ str.WriteString("SACKScoreboard: {")
+ s.ranges.Ascend(func(i btree.Item) bool {
+ str.WriteString(fmt.Sprintf("%v,", i))
+ return true
+ })
+ str.WriteString("}\n")
+ return str.String()
+}
+
+// Delete removes all SACK information prior to seq.
+func (s *SACKScoreboard) Delete(seq seqnum.Value) {
+ if s.Empty() {
+ return
+ }
+ toDelete := []btree.Item{}
+ toInsert := []btree.Item{}
+ r := header.SACKBlock{seq, seq.Add(1)}
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sb := i.(header.SACKBlock)
+ toDelete = append(toDelete, i)
+ if sb.End.LessThanEq(seq) {
+ s.sacked -= sb.Start.Size(sb.End)
+ } else {
+ newSB := header.SACKBlock{seq, sb.End}
+ toInsert = append(toInsert, newSB)
+ s.sacked -= sb.Start.Size(seq)
+ }
+ return true
+ })
+ for _, sb := range toDelete {
+ s.ranges.Delete(sb)
+ }
+ for _, sb := range toInsert {
+ s.ranges.ReplaceOrInsert(sb)
+ }
+}
+
+// Copy provides a copy of the SACK scoreboard.
+func (s *SACKScoreboard) Copy() (sackBlocks []header.SACKBlock, maxSACKED seqnum.Value) {
+ s.ranges.Ascend(func(i btree.Item) bool {
+ sackBlocks = append(sackBlocks, i.(header.SACKBlock))
+ return true
+ })
+ return sackBlocks, s.maxSACKED
+}
+
+// IsRangeLost implements the IsLost(SeqNum) operation defined in RFC 6675
+// section 4 but operates on a range of sequence numbers and returns true if
+// there are at least nDupAckThreshold SACK blocks greater than the range being
+// checked or if at least (nDupAckThreshold-1)*s.smss bytes have been SACKED
+// with sequence numbers greater than the block being checked.
+func (s *SACKScoreboard) IsRangeLost(r header.SACKBlock) bool {
+ if s.Empty() {
+ return false
+ }
+ nDupSACK := 0
+ nDupSACKBytes := seqnum.Size(0)
+ isLost := false
+
+ // We need to check if the immediate lower (if any) sacked
+ // range contains or partially overlaps with r.
+ searchMore := true
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.Contains(r) {
+ searchMore = false
+ return false
+ }
+ if sacked.End.LessThanEq(r.Start) {
+ // all sequence numbers covered by sacked are below
+ // r so we continue searching.
+ return false
+ }
+ // There is a partial overlap. In this case we r.Start is
+ // between sacked.Start & sacked.End and r.End extends beyond
+ // sacked.End.
+ // Move r.Start to sacked.End and continuing searching blocks
+ // above r.Start.
+ r.Start = sacked.End
+ return false
+ })
+
+ if !searchMore {
+ return isLost
+ }
+
+ s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.Contains(r) {
+ return false
+ }
+ nDupSACKBytes += sacked.Start.Size(sacked.End)
+ nDupSACK++
+ if nDupSACK >= nDupAckThreshold || nDupSACKBytes >= seqnum.Size((nDupAckThreshold-1)*s.smss) {
+ isLost = true
+ return false
+ }
+ return true
+ })
+ return isLost
+}
+
+// IsLost implements the IsLost(SeqNum) operation defined in RFC3517 section
+// 4.
+//
+// This routine returns whether the given sequence number is considered to be
+// lost. The routine returns true when either nDupAckThreshold discontiguous
+// SACKed sequences have arrived above 'SeqNum' or (nDupAckThreshold * SMSS)
+// bytes with sequence numbers greater than 'SeqNum' have been SACKed.
+// Otherwise, the routine returns false.
+func (s *SACKScoreboard) IsLost(seq seqnum.Value) bool {
+ return s.IsRangeLost(header.SACKBlock{seq, seq.Add(1)})
+}
+
+// Empty returns true if the SACK scoreboard has no entries, false otherwise.
+func (s *SACKScoreboard) Empty() bool {
+ return s.ranges.Len() == 0
+}
+
+// Sacked returns the current number of bytes held in the SACK scoreboard.
+func (s *SACKScoreboard) Sacked() seqnum.Size {
+ return s.sacked
+}
+
+// MaxSACKED returns the highest sequence number ever inserted in the SACK
+// scoreboard.
+func (s *SACKScoreboard) MaxSACKED() seqnum.Value {
+ return s.maxSACKED
+}
+
+// SMSS returns the sender's MSS as held by the SACK scoreboard.
+func (s *SACKScoreboard) SMSS() uint16 {
+ return s.smss
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
new file mode 100644
index 000000000..450d9fbc1
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -0,0 +1,186 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// segment represents a TCP segment. It holds the payload and parsed TCP segment
+// information, and can be added to intrusive lists.
+// segment is mostly immutable, the only field allowed to change is viewToDeliver.
+//
+// +stateify savable
+type segment struct {
+ segmentEntry
+ refCnt int32
+ id stack.TransportEndpointID `state:"manual"`
+ route stack.Route `state:"manual"`
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // views is used as buffer for data when its length is large
+ // enough to store a VectorisedView.
+ views [8]buffer.View `state:"nosave"`
+ // viewToDeliver keeps track of the next View that should be
+ // delivered by the Read endpoint.
+ viewToDeliver int
+ sequenceNumber seqnum.Value
+ ackNumber seqnum.Value
+ flags uint8
+ window seqnum.Size
+ // csum is only populated for received segments.
+ csum uint16
+ // csumValid is true if the csum in the received segment is valid.
+ csumValid bool
+
+ // parsedOptions stores the parsed values from the options in the segment.
+ parsedOptions header.TCPOptions
+ options []byte `state:".([]byte)"`
+ hasNewSACKInfo bool
+ rcvdTime time.Time `state:".(unixTime)"`
+ // xmitTime is the last transmit time of this segment. A zero value
+ // indicates that the segment has yet to be transmitted.
+ xmitTime time.Time `state:".(unixTime)"`
+}
+
+func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
+ s := &segment{
+ refCnt: 1,
+ id: id,
+ route: r.Clone(),
+ }
+ s.data = vv.Clone(s.views[:])
+ s.rcvdTime = time.Now()
+ return s
+}
+
+func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
+ s := &segment{
+ refCnt: 1,
+ id: id,
+ route: r.Clone(),
+ }
+ s.views[0] = v
+ s.data = buffer.NewVectorisedView(len(v), s.views[:1])
+ s.rcvdTime = time.Now()
+ return s
+}
+
+func (s *segment) clone() *segment {
+ t := &segment{
+ refCnt: 1,
+ id: s.id,
+ sequenceNumber: s.sequenceNumber,
+ ackNumber: s.ackNumber,
+ flags: s.flags,
+ window: s.window,
+ route: s.route.Clone(),
+ viewToDeliver: s.viewToDeliver,
+ rcvdTime: s.rcvdTime,
+ }
+ t.data = s.data.Clone(t.views[:])
+ return t
+}
+
+func (s *segment) flagIsSet(flag uint8) bool {
+ return (s.flags & flag) != 0
+}
+
+func (s *segment) decRef() {
+ if atomic.AddInt32(&s.refCnt, -1) == 0 {
+ s.route.Release()
+ }
+}
+
+func (s *segment) incRef() {
+ atomic.AddInt32(&s.refCnt, 1)
+}
+
+// logicalLen is the segment length in the sequence number space. It's defined
+// as the data length plus one for each of the SYN and FIN bits set.
+func (s *segment) logicalLen() seqnum.Size {
+ l := seqnum.Size(s.data.Size())
+ if s.flagIsSet(header.TCPFlagSyn) {
+ l++
+ }
+ if s.flagIsSet(header.TCPFlagFin) {
+ l++
+ }
+ return l
+}
+
+// parse populates the sequence & ack numbers, flags, and window fields of the
+// segment from the TCP header stored in the data. It then updates the view to
+// skip the header.
+//
+// Returns boolean indicating if the parsing was successful.
+//
+// If checksum verification is not offloaded then parse also verifies the
+// TCP checksum and stores the checksum and result of checksum verification in
+// the csum and csumValid fields of the segment.
+func (s *segment) parse() bool {
+ h := header.TCP(s.data.First())
+
+ // h is the header followed by the payload. We check that the offset to
+ // the data respects the following constraints:
+ // 1. That it's at least the minimum header size; if we don't do this
+ // then part of the header would be delivered to user.
+ // 2. That the header fits within the buffer; if we don't do this, we
+ // would panic when we tried to access data beyond the buffer.
+ //
+ // N.B. The segment has already been validated as having at least the
+ // minimum TCP size before reaching here, so it's safe to read the
+ // fields.
+ offset := int(h.DataOffset())
+ if offset < header.TCPMinimumSize || offset > len(h) {
+ return false
+ }
+
+ s.options = []byte(h[header.TCPMinimumSize:offset])
+ s.parsedOptions = header.ParseTCPOptions(s.options)
+
+ // Query the link capabilities to decide if checksum validation is
+ // required.
+ verifyChecksum := true
+ if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
+ s.csumValid = true
+ verifyChecksum = false
+ s.data.TrimFront(offset)
+ }
+ if verifyChecksum {
+ s.csum = h.Checksum()
+ xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
+ xsum = h.CalculateChecksum(xsum)
+ s.data.TrimFront(offset)
+ xsum = header.ChecksumVV(s.data, xsum)
+ s.csumValid = xsum == 0xffff
+ }
+
+ s.sequenceNumber = seqnum.Value(h.SequenceNumber())
+ s.ackNumber = seqnum.Value(h.AckNumber())
+ s.flags = h.Flags()
+ s.window = seqnum.Size(h.WindowSize())
+ return true
+}
+
+// sackBlock returns a header.SACKBlock that represents this segment.
+func (s *segment) sackBlock() header.SACKBlock {
+ return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())}
+}
diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go
new file mode 100644
index 000000000..9fd061d7d
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_heap.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+type segmentHeap []*segment
+
+// Len returns the length of h.
+func (h segmentHeap) Len() int {
+ return len(h)
+}
+
+// Less determines whether the i-th element of h is less than the j-th element.
+func (h segmentHeap) Less(i, j int) bool {
+ return h[i].sequenceNumber.LessThan(h[j].sequenceNumber)
+}
+
+// Swap swaps the i-th and j-th elements of h.
+func (h segmentHeap) Swap(i, j int) {
+ h[i], h[j] = h[j], h[i]
+}
+
+// Push adds x as the last element of h.
+func (h *segmentHeap) Push(x interface{}) {
+ *h = append(*h, x.(*segment))
+}
+
+// Pop removes the last element of h and returns it.
+func (h *segmentHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[:n-1]
+ return x
+}
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
new file mode 100644
index 000000000..e0759225e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync"
+)
+
+// segmentQueue is a bounded, thread-safe queue of TCP segments.
+//
+// +stateify savable
+type segmentQueue struct {
+ mu sync.Mutex `state:"nosave"`
+ list segmentList `state:"wait"`
+ limit int
+ used int
+}
+
+// empty determines if the queue is empty.
+func (q *segmentQueue) empty() bool {
+ q.mu.Lock()
+ r := q.used == 0
+ q.mu.Unlock()
+
+ return r
+}
+
+// setLimit updates the limit. No segments are immediately dropped in case the
+// queue becomes full due to the new limit.
+func (q *segmentQueue) setLimit(limit int) {
+ q.mu.Lock()
+ q.limit = limit
+ q.mu.Unlock()
+}
+
+// enqueue adds the given segment to the queue.
+//
+// Returns true when the segment is successfully added to the queue, in which
+// case ownership of the reference is transferred to the queue. And returns
+// false if the queue is full, in which case ownership is retained by the
+// caller.
+func (q *segmentQueue) enqueue(s *segment) bool {
+ q.mu.Lock()
+ r := q.used < q.limit
+ if r {
+ q.list.PushBack(s)
+ q.used++
+ }
+ q.mu.Unlock()
+
+ return r
+}
+
+// dequeue removes and returns the next segment from queue, if one exists.
+// Ownership is transferred to the caller, who is responsible for decrementing
+// the ref count when done.
+func (q *segmentQueue) dequeue() *segment {
+ q.mu.Lock()
+ s := q.list.Front()
+ if s != nil {
+ q.list.Remove(s)
+ q.used--
+ }
+ q.mu.Unlock()
+
+ return s
+}
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
new file mode 100644
index 000000000..dd7e14aa6
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// saveData is invoked by stateify.
+func (s *segment) saveData() buffer.VectorisedView {
+ // We cannot save s.data directly as s.data.views may alias to s.views,
+ // which is not allowed by state framework (in-struct pointer).
+ v := make([]buffer.View, len(s.data.Views()))
+ // For views already delivered, we cannot save them directly as they may
+ // have already been sliced and saved elsewhere (e.g., readViews).
+ for i := 0; i < s.viewToDeliver; i++ {
+ v[i] = append([]byte(nil), s.data.Views()[i]...)
+ }
+ for i := s.viewToDeliver; i < len(v); i++ {
+ v[i] = s.data.Views()[i]
+ }
+ return buffer.NewVectorisedView(s.data.Size(), v)
+}
+
+// loadData is invoked by stateify.
+func (s *segment) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the s.data = data.Clone(s.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing s.views for data.views.
+ s.data = data
+}
+
+// saveOptions is invoked by stateify.
+func (s *segment) saveOptions() []byte {
+ // We cannot save s.options directly as it may point to s.data's trimmed
+ // tail, which is not allowed by state framework (in-struct pointer).
+ b := make([]byte, 0, cap(s.options))
+ return append(b, s.options...)
+}
+
+// loadOptions is invoked by stateify.
+func (s *segment) loadOptions(options []byte) {
+ // NOTE: We cannot point s.options back into s.data's trimmed tail. But
+ // it is OK as they do not need to aliased. Plus, options is already
+ // allocated so there is no cost here.
+ s.options = options
+}
+
+// saveRcvdTime is invoked by stateify.
+func (s *segment) saveRcvdTime() unixTime {
+ return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
+}
+
+// loadRcvdTime is invoked by stateify.
+func (s *segment) loadRcvdTime(unix unixTime) {
+ s.rcvdTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveXmitTime is invoked by stateify.
+func (s *segment) saveXmitTime() unixTime {
+ return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (s *segment) loadXmitTime(unix unixTime) {
+ s.rcvdTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
new file mode 100644
index 000000000..afc1d0a55
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -0,0 +1,1180 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "math"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // minRTO is the minimum allowed value for the retransmit timeout.
+ minRTO = 200 * time.Millisecond
+
+ // InitialCwnd is the initial congestion window.
+ InitialCwnd = 10
+
+ // nDupAckThreshold is the number of duplicate ACK's required
+ // before fast-retransmit is entered.
+ nDupAckThreshold = 3
+)
+
+// congestionControl is an interface that must be implemented by any supported
+// congestion control algorithm.
+type congestionControl interface {
+ // HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold
+ // just before entering fast retransmit.
+ HandleNDupAcks()
+
+ // HandleRTOExpired is invoked when the retransmit timer expires.
+ HandleRTOExpired()
+
+ // Update is invoked when processing inbound acks. It's passed the
+ // number of packet's that were acked by the most recent cumulative
+ // acknowledgement.
+ Update(packetsAcked int)
+
+ // PostRecovery is invoked when the sender is exiting a fast retransmit/
+ // recovery phase. This provides congestion control algorithms a way
+ // to adjust their state when exiting recovery.
+ PostRecovery()
+}
+
+// sender holds the state necessary to send TCP segments.
+//
+// +stateify savable
+type sender struct {
+ ep *endpoint
+
+ // lastSendTime is the timestamp when the last packet was sent.
+ lastSendTime time.Time `state:".(unixTime)"`
+
+ // dupAckCount is the number of duplicated acks received. It is used for
+ // fast retransmit.
+ dupAckCount int
+
+ // fr holds state related to fast recovery.
+ fr fastRecovery
+
+ // sndCwnd is the congestion window, in packets.
+ sndCwnd int
+
+ // sndSsthresh is the threshold between slow start and congestion
+ // avoidance.
+ sndSsthresh int
+
+ // sndCAAckCount is the number of packets acknowledged during congestion
+ // avoidance. When enough packets have been ack'd (typically cwnd
+ // packets), the congestion window is incremented by one.
+ sndCAAckCount int
+
+ // outstanding is the number of outstanding packets, that is, packets
+ // that have been sent but not yet acknowledged.
+ outstanding int
+
+ // sndWnd is the send window size.
+ sndWnd seqnum.Size
+
+ // sndUna is the next unacknowledged sequence number.
+ sndUna seqnum.Value
+
+ // sndNxt is the sequence number of the next segment to be sent.
+ sndNxt seqnum.Value
+
+ // sndNxtList is the sequence number of the next segment to be added to
+ // the send list.
+ sndNxtList seqnum.Value
+
+ // rttMeasureSeqNum is the sequence number being used for the latest RTT
+ // measurement.
+ rttMeasureSeqNum seqnum.Value
+
+ // rttMeasureTime is the time when the rttMeasureSeqNum was sent.
+ rttMeasureTime time.Time `state:".(unixTime)"`
+
+ closed bool
+ writeNext *segment
+ writeList segmentList
+ resendTimer timer `state:"nosave"`
+ resendWaker sleep.Waker `state:"nosave"`
+
+ // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
+ // "round-trip time variation" and "retransmit timeout", as defined in
+ // section 2 of RFC 6298.
+ rtt rtt
+ rto time.Duration
+ srttInited bool
+
+ // maxPayloadSize is the maximum size of the payload of a given segment.
+ // It is initialized on demand.
+ maxPayloadSize int
+
+ // gso is set if generic segmentation offload is enabled.
+ gso bool
+
+ // sndWndScale is the number of bits to shift left when reading the send
+ // window size from a segment.
+ sndWndScale uint8
+
+ // maxSentAck is the maxium acknowledgement actually sent.
+ maxSentAck seqnum.Value
+
+ // cc is the congestion control algorithm in use for this sender.
+ cc congestionControl
+}
+
+// rtt is a synchronization wrapper used to appease stateify. See the comment
+// in sender, where it is used.
+//
+// +stateify savable
+type rtt struct {
+ sync.Mutex `state:"nosave"`
+
+ srtt time.Duration
+ rttvar time.Duration
+}
+
+// fastRecovery holds information related to fast recovery from a packet loss.
+//
+// +stateify savable
+type fastRecovery struct {
+ // active whether the endpoint is in fast recovery. The following fields
+ // are only meaningful when active is true.
+ active bool
+
+ // first and last represent the inclusive sequence number range being
+ // recovered.
+ first seqnum.Value
+ last seqnum.Value
+
+ // maxCwnd is the maximum value the congestion window may be inflated to
+ // due to duplicate acks. This exists to avoid attacks where the
+ // receiver intentionally sends duplicate acks to artificially inflate
+ // the sender's cwnd.
+ maxCwnd int
+
+ // highRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase.
+ // See: RFC 6675 Section 2 for details.
+ highRxt seqnum.Value
+
+ // rescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission.
+ // See: RFC 6675 Section 2 for details.
+ rescueRxt seqnum.Value
+}
+
+func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
+ // The sender MUST reduce the TCP data length to account for any IP or
+ // TCP options that it is including in the packets that it sends.
+ // See: https://tools.ietf.org/html/rfc6691#section-2
+ maxPayloadSize := int(mss) - ep.maxOptionSize()
+
+ s := &sender{
+ ep: ep,
+ sndCwnd: InitialCwnd,
+ sndSsthresh: math.MaxInt64,
+ sndWnd: sndWnd,
+ sndUna: iss + 1,
+ sndNxt: iss + 1,
+ sndNxtList: iss + 1,
+ rto: 1 * time.Second,
+ rttMeasureSeqNum: iss + 1,
+ lastSendTime: time.Now(),
+ maxPayloadSize: maxPayloadSize,
+ maxSentAck: irs + 1,
+ fr: fastRecovery{
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
+ last: iss,
+ highRxt: iss,
+ rescueRxt: iss,
+ },
+ gso: ep.gso != nil,
+ }
+
+ if s.gso {
+ s.ep.gso.MSS = uint16(maxPayloadSize)
+ }
+
+ s.cc = s.initCongestionControl(ep.cc)
+
+ // A negative sndWndScale means that no scaling is in use, otherwise we
+ // store the scaling value.
+ if sndWndScale > 0 {
+ s.sndWndScale = uint8(sndWndScale)
+ }
+
+ s.resendTimer.init(&s.resendWaker)
+
+ s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
+
+ // Initialize SACK Scoreboard after updating max payload size as we use
+ // the maxPayloadSize as the smss when determining if a segment is lost
+ // etc.
+ s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+
+ return s
+}
+
+func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+ switch congestionControlName {
+ case ccCubic:
+ return newCubicCC(s)
+ case ccReno:
+ fallthrough
+ default:
+ return newRenoCC(s)
+ }
+}
+
+// updateMaxPayloadSize updates the maximum payload size based on the given
+// MTU. If this is in response to "packet too big" control packets (indicated
+// by the count argument), it also reduces the number of outstanding packets and
+// attempts to retransmit the first packet above the MTU size.
+func (s *sender) updateMaxPayloadSize(mtu, count int) {
+ m := mtu - header.TCPMinimumSize
+
+ m -= s.ep.maxOptionSize()
+
+ // We don't adjust up for now.
+ if m >= s.maxPayloadSize {
+ return
+ }
+
+ // Make sure we can transmit at least one byte.
+ if m <= 0 {
+ m = 1
+ }
+
+ s.maxPayloadSize = m
+ if s.gso {
+ s.ep.gso.MSS = uint16(m)
+ }
+
+ if count == 0 {
+ // updateMaxPayloadSize is also called when the sender is created.
+ // and there is no data to send in such cases. Return immediately.
+ return
+ }
+
+ // Update the scoreboard's smss to reflect the new lowered
+ // maxPayloadSize.
+ s.ep.scoreboard.smss = uint16(m)
+
+ s.outstanding -= count
+ if s.outstanding < 0 {
+ s.outstanding = 0
+ }
+
+ // Rewind writeNext to the first segment exceeding the MTU. Do nothing
+ // if it is already before such a packet.
+ for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+ if seg == s.writeNext {
+ // We got to writeNext before we could find a segment
+ // exceeding the MTU.
+ break
+ }
+
+ if seg.data.Size() > m {
+ // We found a segment exceeding the MTU. Rewind
+ // writeNext and try to retransmit it.
+ s.writeNext = seg
+ break
+ }
+ }
+
+ // Since we likely reduced the number of outstanding packets, we may be
+ // ready to send some more.
+ s.sendData()
+}
+
+// sendAck sends an ACK segment.
+func (s *sender) sendAck() {
+ s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt)
+}
+
+// updateRTO updates the retransmit timeout when a new roud-trip time is
+// available. This is done in accordance with section 2 of RFC 6298.
+func (s *sender) updateRTO(rtt time.Duration) {
+ s.rtt.Lock()
+ if !s.srttInited {
+ s.rtt.rttvar = rtt / 2
+ s.rtt.srtt = rtt
+ s.srttInited = true
+ } else {
+ diff := s.rtt.srtt - rtt
+ if diff < 0 {
+ diff = -diff
+ }
+ // Use RFC6298 standard algorithm to update rttvar and srtt when
+ // no timestamps are available.
+ if !s.ep.sendTSOk {
+ s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4
+ s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8
+ } else {
+ // When we are taking RTT measurements of every ACK then
+ // we need to use a modified method as specified in
+ // https://tools.ietf.org/html/rfc7323#appendix-G
+ if s.outstanding == 0 {
+ s.rtt.Unlock()
+ return
+ }
+ // Netstack measures congestion window/inflight all in
+ // terms of packets and not bytes. This is similar to
+ // how linux also does cwnd and inflight. In practice
+ // this approximation works as expected.
+ expectedSamples := math.Ceil(float64(s.outstanding) / 2)
+
+ // alpha & beta values are the original values as recommended in
+ // https://tools.ietf.org/html/rfc6298#section-2.3.
+ const alpha = 0.125
+ const beta = 0.25
+
+ alphaPrime := alpha / expectedSamples
+ betaPrime := beta / expectedSamples
+ rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds()
+ srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds()
+ s.rtt.rttvar = time.Duration(rttVar * float64(time.Second))
+ s.rtt.srtt = time.Duration(srtt * float64(time.Second))
+ }
+ }
+
+ s.rto = s.rtt.srtt + 4*s.rtt.rttvar
+ s.rtt.Unlock()
+ if s.rto < minRTO {
+ s.rto = minRTO
+ }
+}
+
+// resendSegment resends the first unacknowledged segment.
+func (s *sender) resendSegment() {
+ // Don't use any segments we already sent to measure RTT as they may
+ // have been affected by packets being lost.
+ s.rttMeasureSeqNum = s.sndNxt
+
+ // Resend the segment.
+ if seg := s.writeList.Front(); seg != nil {
+ if seg.data.Size() > s.maxPayloadSize {
+ s.splitSeg(seg, s.maxPayloadSize)
+ }
+
+ // See: RFC 6675 section 5 Step 4.3
+ //
+ // To prevent retransmission, set both the HighRXT and RescueRXT
+ // to the highest sequence number in the retransmitted segment.
+ s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.sendSegment(seg)
+ s.ep.stack.Stats().TCP.FastRetransmit.Increment()
+
+ // Run SetPipe() as per RFC 6675 section 5 Step 4.4
+ s.SetPipe()
+ }
+}
+
+// retransmitTimerExpired is called when the retransmit timer expires, and
+// unacknowledged segments are assumed lost, and thus need to be resent.
+// Returns true if the connection is still usable, or false if the connection
+// is deemed lost.
+func (s *sender) retransmitTimerExpired() bool {
+ // Check if the timer actually expired or if it's a spurious wake due
+ // to a previously orphaned runtime timer.
+ if !s.resendTimer.checkExpiration() {
+ return true
+ }
+
+ s.ep.stack.Stats().TCP.Timeouts.Increment()
+
+ // Give up if we've waited more than a minute since the last resend.
+ if s.rto >= 60*time.Second {
+ return false
+ }
+
+ // Set new timeout. The timer will be restarted by the call to sendData
+ // below.
+ s.rto *= 2
+
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
+ //
+ // Retransmit timeouts:
+ // After a retransmit timeout, record the highest sequence number
+ // transmitted in the variable recover, and exit the fast recovery
+ // procedure if applicable.
+ s.fr.last = s.sndNxt - 1
+
+ if s.fr.active {
+ // We were attempting fast recovery but were not successful.
+ // Leave the state. We don't need to update ssthresh because it
+ // has already been updated when entered fast-recovery.
+ s.leaveFastRecovery()
+ }
+
+ s.cc.HandleRTOExpired()
+
+ // Mark the next segment to be sent as the first unacknowledged one and
+ // start sending again. Set the number of outstanding packets to 0 so
+ // that we'll be able to retransmit.
+ //
+ // We'll keep on transmitting (or retransmitting) as we get acks for
+ // the data we transmit.
+ s.outstanding = 0
+
+ // Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1
+ //
+ // In order to avoid memory deadlocks, the TCP receiver is allowed to
+ // discard data that has already been selectively acknowledged. As a
+ // result, [RFC2018] suggests that a TCP sender SHOULD expunge the SACK
+ // information gathered from a receiver upon a retransmission timeout
+ // (RTO) "since the timeout might indicate that the data receiver has
+ // reneged." Additionally, a TCP sender MUST "ignore prior SACK
+ // information in determining which data to retransmit."
+ //
+ // NOTE: We take the stricter interpretation and just expunge all
+ // information as we lack more rigorous checks to validate if the SACK
+ // information is usable after an RTO.
+ s.ep.scoreboard.Reset()
+ s.writeNext = s.writeList.Front()
+ s.sendData()
+
+ return true
+}
+
+// pCount returns the number of packets in the segment. Due to GSO, a segment
+// can be composed of multiple packets.
+func (s *sender) pCount(seg *segment) int {
+ size := seg.data.Size()
+ if size == 0 {
+ return 1
+ }
+
+ return (size-1)/s.maxPayloadSize + 1
+}
+
+// splitSeg splits a given segment at the size specified and inserts the
+// remainder as a new segment after the current one in the write list.
+func (s *sender) splitSeg(seg *segment, size int) {
+ if seg.data.Size() <= size {
+ return
+ }
+ // Split this segment up.
+ nSeg := seg.clone()
+ nSeg.data.TrimFront(size)
+ nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
+ s.writeList.InsertAfter(seg, nSeg)
+ seg.data.CapLength(size)
+}
+
+// NextSeg implements the RFC6675 NextSeg() operation. It returns segments that
+// match rule 1, 3 and 4 of the NextSeg() operation defined in RFC6675. Rule 2
+// is handled by the normal send logic.
+func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) {
+ var s3 *segment
+ var s4 *segment
+ smss := s.ep.scoreboard.SMSS()
+ // Step 1.
+ for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+ if !s.isAssignedSequenceNumber(seg) {
+ break
+ }
+ segSeq := seg.sequenceNumber
+ if seg.data.Size() > int(smss) {
+ s.splitSeg(seg, int(smss))
+ }
+ // See RFC 6675 Section 4
+ //
+ // 1. If there exists a smallest unSACKED sequence number
+ // 'S2' that meets the following 3 criteria for determinig
+ // loss, the sequence range of one segment of up to SMSS
+ // octects starting with S2 MUST be returned.
+ if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) {
+ // NextSeg():
+ //
+ // (1.a) S2 is greater than HighRxt
+ // (1.b) S2 is less than highest octect covered by
+ // any received SACK.
+ if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) {
+ // NextSeg():
+ // (1.c) IsLost(S2) returns true.
+ if s.ep.scoreboard.IsLost(segSeq) {
+ return seg, s3, s4
+ }
+ // NextSeg():
+ //
+ // (3): If the conditions for rules (1) and (2)
+ // fail, but there exists an unSACKed sequence
+ // number S3 that meets the criteria for
+ // detecting loss given in steps 1.a and 1.b
+ // above (specifically excluding (1.c)) then one
+ // segment of upto SMSS octets starting with S3
+ // SHOULD be returned.
+ if s3 == nil {
+ s3 = seg
+ }
+ }
+ // NextSeg():
+ //
+ // (4) If the conditions for (1), (2) and (3) fail,
+ // but there exists outstanding unSACKED data, we
+ // provide the opportunity for a single "rescue"
+ // retransmission per entry into loss recovery. If
+ // HighACK is greater than RescueRxt, the one
+ // segment of upto SMSS octects that MUST include
+ // the highest outstanding unSACKed sequence number
+ // SHOULD be returned.
+ if s.fr.rescueRxt.LessThan(s.sndUna - 1) {
+ if s4 != nil {
+ if s4.sequenceNumber.LessThan(segSeq) {
+ s4 = seg
+ }
+ } else {
+ s4 = seg
+ }
+ s.fr.rescueRxt = s.fr.last
+ }
+ }
+ }
+
+ return nil, s3, s4
+}
+
+// maybeSendSegment tries to send the specified segment and either coalesces
+// other segments into this one or splits the specified segment based on the
+// lower of the specified limit value or the receivers window size specified by
+// end.
+func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (sent bool) {
+ // We abuse the flags field to determine if we have already
+ // assigned a sequence number to this segment.
+ if !s.isAssignedSequenceNumber(seg) {
+ // Merge segments if allowed.
+ if seg.data.Size() != 0 {
+ available := int(seg.sequenceNumber.Size(end))
+ if available > limit {
+ available = limit
+ }
+
+ // nextTooBig indicates that the next segment was too
+ // large to entirely fit in the current segment. It
+ // would be possible to split the next segment and merge
+ // the portion that fits, but unexpectedly splitting
+ // segments can have user visible side-effects which can
+ // break applications. For example, RFC 7766 section 8
+ // says that the length and data of a DNS response
+ // should be sent in the same TCP segment to avoid
+ // triggering bugs in poorly written DNS
+ // implementations.
+ var nextTooBig bool
+ for seg.Next() != nil && seg.Next().data.Size() != 0 {
+ if seg.data.Size()+seg.Next().data.Size() > available {
+ nextTooBig = true
+ break
+ }
+ seg.data.Append(seg.Next().data)
+
+ // Consume the segment that we just merged in.
+ s.writeList.Remove(seg.Next())
+ }
+ if !nextTooBig && seg.data.Size() < available {
+ // Segment is not full.
+ if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
+ // Nagle's algorithm. From Wikipedia:
+ // Nagle's algorithm works by
+ // combining a number of small
+ // outgoing messages and sending them
+ // all at once. Specifically, as long
+ // as there is a sent packet for which
+ // the sender has received no
+ // acknowledgment, the sender should
+ // keep buffering its output until it
+ // has a full packet's worth of
+ // output, thus allowing output to be
+ // sent all at once.
+ return false
+ }
+ if atomic.LoadUint32(&s.ep.cork) != 0 {
+ // Hold back the segment until full.
+ return false
+ }
+ }
+ }
+
+ // Assign flags. We don't do it above so that we can merge
+ // additional data if Nagle holds the segment.
+ seg.sequenceNumber = s.sndNxt
+ seg.flags = header.TCPFlagAck | header.TCPFlagPsh
+ }
+
+ var segEnd seqnum.Value
+ if seg.data.Size() == 0 {
+ if s.writeList.Back() != seg {
+ panic("FIN segments must be the final segment in the write list.")
+ }
+ seg.flags = header.TCPFlagAck | header.TCPFlagFin
+ segEnd = seg.sequenceNumber.Add(1)
+ } else {
+ // We're sending a non-FIN segment.
+ if seg.flags&header.TCPFlagFin != 0 {
+ panic("Netstack queues FIN segments without data.")
+ }
+
+ if !seg.sequenceNumber.LessThan(end) {
+ return false
+ }
+
+ available := int(seg.sequenceNumber.Size(end))
+ if available == 0 {
+ return false
+ }
+ if available > limit {
+ available = limit
+ }
+
+ if seg.data.Size() > available {
+ s.splitSeg(seg, available)
+ }
+
+ segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ }
+
+ s.sendSegment(seg)
+
+ // Update sndNxt if we actually sent new data (as opposed to
+ // retransmitting some previously sent data).
+ if s.sndNxt.LessThan(segEnd) {
+ s.sndNxt = segEnd
+ }
+
+ return true
+}
+
+// handleSACKRecovery implements the loss recovery phase as described in RFC6675
+// section 5, step C.
+func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
+ s.SetPipe()
+ for s.outstanding < s.sndCwnd {
+ nextSeg, s3, s4 := s.NextSeg()
+ if nextSeg == nil {
+ // NextSeg():
+ //
+ // Step (2): "If no sequence number 'S2' per rule (1)
+ // exists but there exists available unsent data and the
+ // receiver's advertised window allows, the sequence
+ // range of one segment of up to SMSS octets of
+ // previously unsent data starting with sequence number
+ // HighData+1 MUST be returned."
+ for seg := s.writeNext; seg != nil; seg = seg.Next() {
+ if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) {
+ continue
+ }
+ // Step C.3 described below is handled by
+ // maybeSendSegment which increments sndNxt when
+ // a segment is transmitted.
+ //
+ // Step C.3 "If any of the data octets sent in
+ // (C.1) are above HighData, HighData must be
+ // updated to reflect the transmission of
+ // previously unsent data."
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding++
+ s.writeNext = seg.Next()
+ nextSeg = seg
+ break
+ }
+ if nextSeg != nil {
+ continue
+ }
+ }
+ rescueRtx := false
+ if nextSeg == nil && s3 != nil {
+ nextSeg = s3
+ }
+ if nextSeg == nil && s4 != nil {
+ nextSeg = s4
+ rescueRtx = true
+ }
+ if nextSeg == nil {
+ break
+ }
+ segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
+ if !rescueRtx && nextSeg.sequenceNumber.LessThan(s.sndNxt) {
+ // RFC 6675, Step C.2
+ //
+ // "If any of the data octets sent in (C.1) are below
+ // HighData, HighRxt MUST be set to the highest sequence
+ // number of the retransmitted segment unless NextSeg ()
+ // rule (4) was invoked for this retransmission."
+ s.fr.highRxt = segEnd - 1
+ }
+
+ // RFC 6675, Step C.4.
+ //
+ // "The estimate of the amount of data outstanding in the network
+ // must be updated by incrementing pipe by the number of octets
+ // transmitted in (C.1)."
+ s.outstanding++
+ dataSent = true
+ s.sendSegment(nextSeg)
+ }
+ return dataSent
+}
+
+// sendData sends new data segments. It is called when data becomes available or
+// when the send window opens up.
+func (s *sender) sendData() {
+ limit := s.maxPayloadSize
+ if s.gso {
+ limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize)
+ }
+ end := s.sndUna.Add(s.sndWnd)
+
+ // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
+ // "A TCP SHOULD set cwnd to no more than RW before beginning
+ // transmission if the TCP has not sent data in the interval exceeding
+ // the retrasmission timeout."
+ if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto {
+ if s.sndCwnd > InitialCwnd {
+ s.sndCwnd = InitialCwnd
+ }
+ }
+
+ var dataSent bool
+
+ // RFC 6675 recovery algorithm step C 1-5.
+ if s.fr.active && s.ep.sackPermitted {
+ dataSent = s.handleSACKRecovery(s.maxPayloadSize, end)
+ } else {
+ for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
+ cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
+ if cwndLimit < limit {
+ limit = cwndLimit
+ }
+ if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ continue
+ }
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding++
+ s.writeNext = seg.Next()
+ }
+ }
+
+ if dataSent {
+ // We sent data, so we should stop the keepalive timer to ensure
+ // that no keepalives are sent while there is pending data.
+ s.ep.disableKeepaliveTimer()
+ }
+
+ // Enable the timer if we have pending data and it's not enabled yet.
+ if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
+ s.resendTimer.enable(s.rto)
+ }
+ // If we have no more pending data, start the keepalive timer.
+ if s.sndUna == s.sndNxt {
+ s.ep.resetKeepaliveTimer(false)
+ }
+}
+
+func (s *sender) enterFastRecovery() {
+ s.fr.active = true
+ // Save state to reflect we're now in fast recovery.
+ //
+ // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
+ // We inflate the cwnd by 3 to account for the 3 packets which triggered
+ // the 3 duplicate ACKs and are now not in flight.
+ s.sndCwnd = s.sndSsthresh + 3
+ s.fr.first = s.sndUna
+ s.fr.last = s.sndNxt - 1
+ s.fr.maxCwnd = s.sndCwnd + s.outstanding
+ if s.ep.sackPermitted {
+ s.ep.stack.Stats().TCP.SACKRecovery.Increment()
+ return
+ }
+ s.ep.stack.Stats().TCP.FastRecovery.Increment()
+}
+
+func (s *sender) leaveFastRecovery() {
+ s.fr.active = false
+ s.fr.maxCwnd = 0
+ s.dupAckCount = 0
+
+ // Deflate cwnd. It had been artificially inflated when new dups arrived.
+ s.sndCwnd = s.sndSsthresh
+
+ s.cc.PostRecovery()
+}
+
+func (s *sender) handleFastRecovery(seg *segment) (rtx bool) {
+ ack := seg.ackNumber
+ // We are in fast recovery mode. Ignore the ack if it's out of
+ // range.
+ if !ack.InRange(s.sndUna, s.sndNxt+1) {
+ return false
+ }
+
+ // Leave fast recovery if it acknowledges all the data covered by
+ // this fast recovery session.
+ if s.fr.last.LessThan(ack) {
+ s.leaveFastRecovery()
+ return false
+ }
+
+ if s.ep.sackPermitted {
+ // When SACK is enabled we let retransmission be governed by
+ // the SACK logic.
+ return false
+ }
+
+ // Don't count this as a duplicate if it is carrying data or
+ // updating the window.
+ if seg.logicalLen() != 0 || s.sndWnd != seg.window {
+ return false
+ }
+
+ // Inflate the congestion window if we're getting duplicate acks
+ // for the packet we retransmitted.
+ if ack == s.fr.first {
+ // We received a dup, inflate the congestion window by 1 packet
+ // if we're not at the max yet. Only inflate the window if
+ // regular FastRecovery is in use, RFC6675 does not require
+ // inflating cwnd on duplicate ACKs.
+ if s.sndCwnd < s.fr.maxCwnd {
+ s.sndCwnd++
+ }
+ return false
+ }
+
+ // A partial ack was received. Retransmit this packet and
+ // remember it so that we don't retransmit it again. We don't
+ // inflate the window because we're putting the same packet back
+ // onto the wire.
+ //
+ // N.B. The retransmit timer will be reset by the caller.
+ s.fr.first = ack
+ s.dupAckCount = 0
+ return true
+}
+
+// isAssignedSequenceNumber relies on the fact that we only set flags once a
+// sequencenumber is assigned and that is only done right before we send the
+// segment. As a result any segment that has a non-zero flag has a valid
+// sequence number assigned to it.
+func (s *sender) isAssignedSequenceNumber(seg *segment) bool {
+ return seg.flags != 0
+}
+
+// SetPipe implements the SetPipe() function described in RFC6675. Netstack
+// maintains the congestion window in number of packets and not bytes, so
+// SetPipe() here measures number of outstanding packets rather than actual
+// outstanding bytes in the network.
+func (s *sender) SetPipe() {
+ // If SACK isn't permitted or it is permitted but recovery is not active
+ // then ignore pipe calculations.
+ if !s.ep.sackPermitted || !s.fr.active {
+ return
+ }
+ pipe := 0
+ smss := seqnum.Size(s.ep.scoreboard.SMSS())
+ for s1 := s.writeList.Front(); s1 != nil && s1.data.Size() != 0 && s.isAssignedSequenceNumber(s1); s1 = s1.Next() {
+ // With GSO each segment can be much larger than SMSS. So check the segment
+ // in SMSS sized ranges.
+ segEnd := s1.sequenceNumber.Add(seqnum.Size(s1.data.Size()))
+ for startSeq := s1.sequenceNumber; startSeq.LessThan(segEnd); startSeq = startSeq.Add(smss) {
+ endSeq := startSeq.Add(smss)
+ if segEnd.LessThan(endSeq) {
+ endSeq = segEnd
+ }
+ sb := header.SACKBlock{startSeq, endSeq}
+ // SetPipe():
+ //
+ // After initializing pipe to zero, the following steps are
+ // taken for each octet 'S1' in the sequence space between
+ // HighACK and HighData that has not been SACKed:
+ if !s1.sequenceNumber.LessThan(s.sndNxt) {
+ break
+ }
+ if s.ep.scoreboard.IsSACKED(sb) {
+ continue
+ }
+
+ // SetPipe():
+ //
+ // (a) If IsLost(S1) returns false, Pipe is incremened by 1.
+ //
+ // NOTE: here we mark the whole segment as lost. We do not try
+ // and test every byte in our write buffer as we maintain our
+ // pipe in terms of oustanding packets and not bytes.
+ if !s.ep.scoreboard.IsRangeLost(sb) {
+ pipe++
+ }
+ // SetPipe():
+ // (b) If S1 <= HighRxt, Pipe is incremented by 1.
+ if s1.sequenceNumber.LessThanEq(s.fr.highRxt) {
+ pipe++
+ }
+ }
+ }
+ s.outstanding = pipe
+}
+
+// checkDuplicateAck is called when an ack is received. It manages the state
+// related to duplicate acks and determines if a retransmit is needed according
+// to the rules in RFC 6582 (NewReno).
+func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
+ ack := seg.ackNumber
+ if s.fr.active {
+ return s.handleFastRecovery(seg)
+ }
+
+ // We're not in fast recovery yet. A segment is considered a duplicate
+ // only if it doesn't carry any data and doesn't update the send window,
+ // because if it does, it wasn't sent in response to an out-of-order
+ // segment. If SACK is enabled then we have an additional check to see
+ // if the segment carries new SACK information. If it does then it is
+ // considered a duplicate ACK as per RFC6675.
+ if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt {
+ if !s.ep.sackPermitted || !seg.hasNewSACKInfo {
+ s.dupAckCount = 0
+ return false
+ }
+ }
+
+ s.dupAckCount++
+
+ // Do not enter fast recovery until we reach nDupAckThreshold or the
+ // first unacknowledged byte is considered lost as per SACK scoreboard.
+ if s.dupAckCount < nDupAckThreshold || (s.ep.sackPermitted && !s.ep.scoreboard.IsLost(s.sndUna)) {
+ // RFC 6675 Step 3.
+ s.fr.highRxt = s.sndUna - 1
+ // Do run SetPipe() to calculate the outstanding segments.
+ s.SetPipe()
+ return false
+ }
+
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
+ //
+ // We only do the check here, the incrementing of last to the highest
+ // sequence number transmitted till now is done when enterFastRecovery
+ // is invoked.
+ if !s.fr.last.LessThan(seg.ackNumber) {
+ s.dupAckCount = 0
+ return false
+ }
+ s.cc.HandleNDupAcks()
+ s.enterFastRecovery()
+ s.dupAckCount = 0
+ return true
+}
+
+// handleRcvdSegment is called when a segment is received; it is responsible for
+// updating the send-related state.
+func (s *sender) handleRcvdSegment(seg *segment) {
+ // Check if we can extract an RTT measurement from this ack.
+ if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ s.updateRTO(time.Now().Sub(s.rttMeasureTime))
+ s.rttMeasureSeqNum = s.sndNxt
+ }
+
+ // Update Timestamp if required. See RFC7323, section-4.3.
+ if s.ep.sendTSOk && seg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+ }
+
+ // Insert SACKBlock information into our scoreboard.
+ if s.ep.sackPermitted {
+ for _, sb := range seg.parsedOptions.SACKBlocks {
+ // Only insert the SACK block if the following holds
+ // true:
+ // * SACK block acks data after the ack number in the
+ // current segment.
+ // * SACK block represents a sequence
+ // between sndUna and sndNxt (i.e. data that is
+ // currently unacked and in-flight).
+ // * SACK block that has not been SACKed already.
+ //
+ // NOTE: This check specifically excludes DSACK blocks
+ // which have start/end before sndUna and are used to
+ // indicate spurious retransmissions.
+ if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ s.ep.scoreboard.Insert(sb)
+ seg.hasNewSACKInfo = true
+ }
+ }
+ s.SetPipe()
+ }
+
+ // Count the duplicates and do the fast retransmit if needed.
+ rtx := s.checkDuplicateAck(seg)
+
+ // Stash away the current window size.
+ s.sndWnd = seg.window
+
+ // Ignore ack if it doesn't acknowledge any new data.
+ ack := seg.ackNumber
+ if (ack - 1).InRange(s.sndUna, s.sndNxt) {
+ s.dupAckCount = 0
+
+ // See : https://tools.ietf.org/html/rfc1323#section-3.3.
+ // Specifically we should only update the RTO using TSEcr if the
+ // following condition holds:
+ //
+ // A TSecr value received in a segment is used to update the
+ // averaged RTT measurement only if the segment acknowledges
+ // some new data, i.e., only if it advances the left edge of
+ // the send window.
+ if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ // TSVal/Ecr values sent by Netstack are at a millisecond
+ // granularity.
+ elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ s.updateRTO(elapsed)
+ }
+
+ // When an ack is received we must rearm the timer.
+ // RFC 6298 5.2
+ s.resendTimer.enable(s.rto)
+
+ // Remove all acknowledged data from the write list.
+ acked := s.sndUna.Size(ack)
+ s.sndUna = ack
+
+ ackLeft := acked
+ originalOutstanding := s.outstanding
+ for ackLeft > 0 {
+ // We use logicalLen here because we can have FIN
+ // segments (which are always at the end of list) that
+ // have no data, but do consume a sequence number.
+ seg := s.writeList.Front()
+ datalen := seg.logicalLen()
+
+ if datalen > ackLeft {
+ prevCount := s.pCount(seg)
+ seg.data.TrimFront(int(ackLeft))
+ seg.sequenceNumber.UpdateForward(ackLeft)
+ s.outstanding -= prevCount - s.pCount(seg)
+ break
+ }
+
+ if s.writeNext == seg {
+ s.writeNext = seg.Next()
+ }
+ s.writeList.Remove(seg)
+
+ // if SACK is enabled then Only reduce outstanding if
+ // the segment was not previously SACKED as these have
+ // already been accounted for in SetPipe().
+ if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ s.outstanding -= s.pCount(seg)
+ }
+ seg.decRef()
+ ackLeft -= datalen
+ }
+
+ // Update the send buffer usage and notify potential waiters.
+ s.ep.updateSndBufferUsage(int(acked))
+
+ // Clear SACK information for all acked data.
+ s.ep.scoreboard.Delete(s.sndUna)
+
+ // If we are not in fast recovery then update the congestion
+ // window based on the number of acknowledged packets.
+ if !s.fr.active {
+ s.cc.Update(originalOutstanding - s.outstanding)
+ }
+
+ // It is possible for s.outstanding to drop below zero if we get
+ // a retransmit timeout, reset outstanding to zero but later
+ // get an ack that cover previously sent data.
+ if s.outstanding < 0 {
+ s.outstanding = 0
+ }
+
+ s.SetPipe()
+
+ // If all outstanding data was acknowledged the disable the timer.
+ // RFC 6298 Rule 5.3
+ if s.sndUna == s.sndNxt {
+ s.outstanding = 0
+ s.resendTimer.disable()
+ }
+ }
+ // Now that we've popped all acknowledged data from the retransmit
+ // queue, retransmit if needed.
+ if rtx {
+ s.resendSegment()
+ }
+
+ // Send more data now that some of the pending data has been ack'd, or
+ // that the window opened up, or the congestion window was inflated due
+ // to a duplicate ack during fast recovery. This will also re-enable
+ // the retransmit timer if needed.
+ if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+ s.sendData()
+ }
+}
+
+// sendSegment sends the specified segment.
+func (s *sender) sendSegment(seg *segment) *tcpip.Error {
+ if !seg.xmitTime.IsZero() {
+ s.ep.stack.Stats().TCP.Retransmits.Increment()
+ if s.sndCwnd < s.sndSsthresh {
+ s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
+ }
+ }
+ seg.xmitTime = time.Now()
+ return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
+}
+
+// sendSegmentFromView sends a new segment containing the given payload, flags
+// and sequence number.
+func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
+ s.lastSendTime = time.Now()
+ if seq == s.rttMeasureSeqNum {
+ s.rttMeasureTime = s.lastSendTime
+ }
+
+ rcvNxt, rcvWnd := s.ep.rcv.getSendParams()
+
+ // Remember the max sent ack.
+ s.maxSentAck = rcvNxt
+
+ // Every time a packet containing data is sent (including a
+ // retransmission), if SACK is enabled then use the conservative timer
+ // described in RFC6675 Section 4.0, otherwise follow the standard time
+ // described in RFC6298 Section 5.2.
+ if data.Size() != 0 {
+ if s.ep.sackPermitted {
+ s.resendTimer.enable(s.rto)
+ } else {
+ if !s.resendTimer.enabled() {
+ s.resendTimer.enable(s.rto)
+ }
+ }
+ }
+
+ return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
+}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
new file mode 100644
index 000000000..12eff8afc
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -0,0 +1,50 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// +stateify savable
+type unixTime struct {
+ second int64
+ nano int64
+}
+
+// saveLastSendTime is invoked by stateify.
+func (s *sender) saveLastSendTime() unixTime {
+ return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()}
+}
+
+// loadLastSendTime is invoked by stateify.
+func (s *sender) loadLastSendTime(unix unixTime) {
+ s.lastSendTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRttMeasureTime is invoked by stateify.
+func (s *sender) saveRttMeasureTime() unixTime {
+ return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()}
+}
+
+// loadRttMeasureTime is invoked by stateify.
+func (s *sender) loadRttMeasureTime(unix unixTime) {
+ s.rttMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// afterLoad is invoked by stateify.
+func (s *sender) afterLoad() {
+ s.resendTimer.init(&s.resendWaker)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go
new file mode 100755
index 000000000..029f98a11
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go
@@ -0,0 +1,173 @@
+package tcp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type segmentElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type segmentList struct {
+ head *segment
+ tail *segment
+}
+
+// Reset resets list l to the empty state.
+func (l *segmentList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *segmentList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *segmentList) Front() *segment {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *segmentList) Back() *segment {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *segmentList) PushFront(e *segment) {
+ segmentElementMapper{}.linkerFor(e).SetNext(l.head)
+ segmentElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ segmentElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *segmentList) PushBack(e *segment) {
+ segmentElementMapper{}.linkerFor(e).SetNext(nil)
+ segmentElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *segmentList) PushBackList(m *segmentList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *segmentList) InsertAfter(b, e *segment) {
+ a := segmentElementMapper{}.linkerFor(b).Next()
+ segmentElementMapper{}.linkerFor(e).SetNext(a)
+ segmentElementMapper{}.linkerFor(e).SetPrev(b)
+ segmentElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ segmentElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *segmentList) InsertBefore(a, e *segment) {
+ b := segmentElementMapper{}.linkerFor(a).Prev()
+ segmentElementMapper{}.linkerFor(e).SetNext(a)
+ segmentElementMapper{}.linkerFor(e).SetPrev(b)
+ segmentElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ segmentElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *segmentList) Remove(e *segment) {
+ prev := segmentElementMapper{}.linkerFor(e).Prev()
+ next := segmentElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ segmentElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ segmentElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type segmentEntry struct {
+ next *segment
+ prev *segment
+}
+
+// Next returns the entry that follows e in the list.
+func (e *segmentEntry) Next() *segment {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *segmentEntry) Prev() *segment {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *segmentEntry) SetNext(elem *segment) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *segmentEntry) SetPrev(elem *segment) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
new file mode 100755
index 000000000..9049a99b2
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -0,0 +1,400 @@
+// automatically generated by stateify.
+
+package tcp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *SACKInfo) beforeSave() {}
+func (x *SACKInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Blocks", &x.Blocks)
+ m.Save("NumBlocks", &x.NumBlocks)
+}
+
+func (x *SACKInfo) afterLoad() {}
+func (x *SACKInfo) load(m state.Map) {
+ m.Load("Blocks", &x.Blocks)
+ m.Load("NumBlocks", &x.NumBlocks)
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var lastError string = x.saveLastError()
+ m.SaveValue("lastError", lastError)
+ var state endpointState = x.saveState()
+ m.SaveValue("state", state)
+ var hardError string = x.saveHardError()
+ m.SaveValue("hardError", hardError)
+ var acceptedChan []*endpoint = x.saveAcceptedChan()
+ m.SaveValue("acceptedChan", acceptedChan)
+ m.Save("netProto", &x.netProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvBufUsed", &x.rcvBufUsed)
+ m.Save("id", &x.id)
+ m.Save("isRegistered", &x.isRegistered)
+ m.Save("v6only", &x.v6only)
+ m.Save("isConnectNotified", &x.isConnectNotified)
+ m.Save("broadcast", &x.broadcast)
+ m.Save("workerRunning", &x.workerRunning)
+ m.Save("workerCleanup", &x.workerCleanup)
+ m.Save("sendTSOk", &x.sendTSOk)
+ m.Save("recentTS", &x.recentTS)
+ m.Save("tsOffset", &x.tsOffset)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("sackPermitted", &x.sackPermitted)
+ m.Save("sack", &x.sack)
+ m.Save("reusePort", &x.reusePort)
+ m.Save("delay", &x.delay)
+ m.Save("cork", &x.cork)
+ m.Save("scoreboard", &x.scoreboard)
+ m.Save("reuseAddr", &x.reuseAddr)
+ m.Save("slowAck", &x.slowAck)
+ m.Save("segmentQueue", &x.segmentQueue)
+ m.Save("synRcvdCount", &x.synRcvdCount)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("sndBufUsed", &x.sndBufUsed)
+ m.Save("sndClosed", &x.sndClosed)
+ m.Save("sndBufInQueue", &x.sndBufInQueue)
+ m.Save("sndQueue", &x.sndQueue)
+ m.Save("cc", &x.cc)
+ m.Save("packetTooBigCount", &x.packetTooBigCount)
+ m.Save("sndMTU", &x.sndMTU)
+ m.Save("keepalive", &x.keepalive)
+ m.Save("rcv", &x.rcv)
+ m.Save("snd", &x.snd)
+ m.Save("bindAddress", &x.bindAddress)
+ m.Save("connectingAddress", &x.connectingAddress)
+ m.Save("gso", &x.gso)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.LoadWait("waiterQueue", &x.waiterQueue)
+ m.LoadWait("rcvList", &x.rcvList)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvBufUsed", &x.rcvBufUsed)
+ m.Load("id", &x.id)
+ m.Load("isRegistered", &x.isRegistered)
+ m.Load("v6only", &x.v6only)
+ m.Load("isConnectNotified", &x.isConnectNotified)
+ m.Load("broadcast", &x.broadcast)
+ m.Load("workerRunning", &x.workerRunning)
+ m.Load("workerCleanup", &x.workerCleanup)
+ m.Load("sendTSOk", &x.sendTSOk)
+ m.Load("recentTS", &x.recentTS)
+ m.Load("tsOffset", &x.tsOffset)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("sackPermitted", &x.sackPermitted)
+ m.Load("sack", &x.sack)
+ m.Load("reusePort", &x.reusePort)
+ m.Load("delay", &x.delay)
+ m.Load("cork", &x.cork)
+ m.Load("scoreboard", &x.scoreboard)
+ m.Load("reuseAddr", &x.reuseAddr)
+ m.Load("slowAck", &x.slowAck)
+ m.LoadWait("segmentQueue", &x.segmentQueue)
+ m.Load("synRcvdCount", &x.synRcvdCount)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("sndBufUsed", &x.sndBufUsed)
+ m.Load("sndClosed", &x.sndClosed)
+ m.Load("sndBufInQueue", &x.sndBufInQueue)
+ m.LoadWait("sndQueue", &x.sndQueue)
+ m.Load("cc", &x.cc)
+ m.Load("packetTooBigCount", &x.packetTooBigCount)
+ m.Load("sndMTU", &x.sndMTU)
+ m.Load("keepalive", &x.keepalive)
+ m.LoadWait("rcv", &x.rcv)
+ m.LoadWait("snd", &x.snd)
+ m.Load("bindAddress", &x.bindAddress)
+ m.Load("connectingAddress", &x.connectingAddress)
+ m.Load("gso", &x.gso)
+ m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) })
+ m.LoadValue("state", new(endpointState), func(y interface{}) { x.loadState(y.(endpointState)) })
+ m.LoadValue("hardError", new(string), func(y interface{}) { x.loadHardError(y.(string)) })
+ m.LoadValue("acceptedChan", new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *keepalive) beforeSave() {}
+func (x *keepalive) save(m state.Map) {
+ x.beforeSave()
+ m.Save("enabled", &x.enabled)
+ m.Save("idle", &x.idle)
+ m.Save("interval", &x.interval)
+ m.Save("count", &x.count)
+ m.Save("unacked", &x.unacked)
+}
+
+func (x *keepalive) afterLoad() {}
+func (x *keepalive) load(m state.Map) {
+ m.Load("enabled", &x.enabled)
+ m.Load("idle", &x.idle)
+ m.Load("interval", &x.interval)
+ m.Load("count", &x.count)
+ m.Load("unacked", &x.unacked)
+}
+
+func (x *receiver) beforeSave() {}
+func (x *receiver) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ep", &x.ep)
+ m.Save("rcvNxt", &x.rcvNxt)
+ m.Save("rcvAcc", &x.rcvAcc)
+ m.Save("rcvWndScale", &x.rcvWndScale)
+ m.Save("closed", &x.closed)
+ m.Save("pendingRcvdSegments", &x.pendingRcvdSegments)
+ m.Save("pendingBufUsed", &x.pendingBufUsed)
+ m.Save("pendingBufSize", &x.pendingBufSize)
+}
+
+func (x *receiver) afterLoad() {}
+func (x *receiver) load(m state.Map) {
+ m.Load("ep", &x.ep)
+ m.Load("rcvNxt", &x.rcvNxt)
+ m.Load("rcvAcc", &x.rcvAcc)
+ m.Load("rcvWndScale", &x.rcvWndScale)
+ m.Load("closed", &x.closed)
+ m.Load("pendingRcvdSegments", &x.pendingRcvdSegments)
+ m.Load("pendingBufUsed", &x.pendingBufUsed)
+ m.Load("pendingBufSize", &x.pendingBufSize)
+}
+
+func (x *renoState) beforeSave() {}
+func (x *renoState) save(m state.Map) {
+ x.beforeSave()
+ m.Save("s", &x.s)
+}
+
+func (x *renoState) afterLoad() {}
+func (x *renoState) load(m state.Map) {
+ m.Load("s", &x.s)
+}
+
+func (x *SACKScoreboard) beforeSave() {}
+func (x *SACKScoreboard) save(m state.Map) {
+ x.beforeSave()
+ m.Save("smss", &x.smss)
+ m.Save("maxSACKED", &x.maxSACKED)
+}
+
+func (x *SACKScoreboard) afterLoad() {}
+func (x *SACKScoreboard) load(m state.Map) {
+ m.Load("smss", &x.smss)
+ m.Load("maxSACKED", &x.maxSACKED)
+}
+
+func (x *segment) beforeSave() {}
+func (x *segment) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ var options []byte = x.saveOptions()
+ m.SaveValue("options", options)
+ var rcvdTime unixTime = x.saveRcvdTime()
+ m.SaveValue("rcvdTime", rcvdTime)
+ var xmitTime unixTime = x.saveXmitTime()
+ m.SaveValue("xmitTime", xmitTime)
+ m.Save("segmentEntry", &x.segmentEntry)
+ m.Save("refCnt", &x.refCnt)
+ m.Save("viewToDeliver", &x.viewToDeliver)
+ m.Save("sequenceNumber", &x.sequenceNumber)
+ m.Save("ackNumber", &x.ackNumber)
+ m.Save("flags", &x.flags)
+ m.Save("window", &x.window)
+ m.Save("csum", &x.csum)
+ m.Save("csumValid", &x.csumValid)
+ m.Save("parsedOptions", &x.parsedOptions)
+ m.Save("hasNewSACKInfo", &x.hasNewSACKInfo)
+}
+
+func (x *segment) afterLoad() {}
+func (x *segment) load(m state.Map) {
+ m.Load("segmentEntry", &x.segmentEntry)
+ m.Load("refCnt", &x.refCnt)
+ m.Load("viewToDeliver", &x.viewToDeliver)
+ m.Load("sequenceNumber", &x.sequenceNumber)
+ m.Load("ackNumber", &x.ackNumber)
+ m.Load("flags", &x.flags)
+ m.Load("window", &x.window)
+ m.Load("csum", &x.csum)
+ m.Load("csumValid", &x.csumValid)
+ m.Load("parsedOptions", &x.parsedOptions)
+ m.Load("hasNewSACKInfo", &x.hasNewSACKInfo)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+ m.LoadValue("options", new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) })
+ m.LoadValue("rcvdTime", new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) })
+ m.LoadValue("xmitTime", new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) })
+}
+
+func (x *segmentQueue) beforeSave() {}
+func (x *segmentQueue) save(m state.Map) {
+ x.beforeSave()
+ m.Save("list", &x.list)
+ m.Save("limit", &x.limit)
+ m.Save("used", &x.used)
+}
+
+func (x *segmentQueue) afterLoad() {}
+func (x *segmentQueue) load(m state.Map) {
+ m.LoadWait("list", &x.list)
+ m.Load("limit", &x.limit)
+ m.Load("used", &x.used)
+}
+
+func (x *sender) beforeSave() {}
+func (x *sender) save(m state.Map) {
+ x.beforeSave()
+ var lastSendTime unixTime = x.saveLastSendTime()
+ m.SaveValue("lastSendTime", lastSendTime)
+ var rttMeasureTime unixTime = x.saveRttMeasureTime()
+ m.SaveValue("rttMeasureTime", rttMeasureTime)
+ m.Save("ep", &x.ep)
+ m.Save("dupAckCount", &x.dupAckCount)
+ m.Save("fr", &x.fr)
+ m.Save("sndCwnd", &x.sndCwnd)
+ m.Save("sndSsthresh", &x.sndSsthresh)
+ m.Save("sndCAAckCount", &x.sndCAAckCount)
+ m.Save("outstanding", &x.outstanding)
+ m.Save("sndWnd", &x.sndWnd)
+ m.Save("sndUna", &x.sndUna)
+ m.Save("sndNxt", &x.sndNxt)
+ m.Save("sndNxtList", &x.sndNxtList)
+ m.Save("rttMeasureSeqNum", &x.rttMeasureSeqNum)
+ m.Save("closed", &x.closed)
+ m.Save("writeNext", &x.writeNext)
+ m.Save("writeList", &x.writeList)
+ m.Save("rtt", &x.rtt)
+ m.Save("rto", &x.rto)
+ m.Save("srttInited", &x.srttInited)
+ m.Save("maxPayloadSize", &x.maxPayloadSize)
+ m.Save("gso", &x.gso)
+ m.Save("sndWndScale", &x.sndWndScale)
+ m.Save("maxSentAck", &x.maxSentAck)
+ m.Save("cc", &x.cc)
+}
+
+func (x *sender) load(m state.Map) {
+ m.Load("ep", &x.ep)
+ m.Load("dupAckCount", &x.dupAckCount)
+ m.Load("fr", &x.fr)
+ m.Load("sndCwnd", &x.sndCwnd)
+ m.Load("sndSsthresh", &x.sndSsthresh)
+ m.Load("sndCAAckCount", &x.sndCAAckCount)
+ m.Load("outstanding", &x.outstanding)
+ m.Load("sndWnd", &x.sndWnd)
+ m.Load("sndUna", &x.sndUna)
+ m.Load("sndNxt", &x.sndNxt)
+ m.Load("sndNxtList", &x.sndNxtList)
+ m.Load("rttMeasureSeqNum", &x.rttMeasureSeqNum)
+ m.Load("closed", &x.closed)
+ m.Load("writeNext", &x.writeNext)
+ m.Load("writeList", &x.writeList)
+ m.Load("rtt", &x.rtt)
+ m.Load("rto", &x.rto)
+ m.Load("srttInited", &x.srttInited)
+ m.Load("maxPayloadSize", &x.maxPayloadSize)
+ m.Load("gso", &x.gso)
+ m.Load("sndWndScale", &x.sndWndScale)
+ m.Load("maxSentAck", &x.maxSentAck)
+ m.Load("cc", &x.cc)
+ m.LoadValue("lastSendTime", new(unixTime), func(y interface{}) { x.loadLastSendTime(y.(unixTime)) })
+ m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *rtt) beforeSave() {}
+func (x *rtt) save(m state.Map) {
+ x.beforeSave()
+ m.Save("srtt", &x.srtt)
+ m.Save("rttvar", &x.rttvar)
+}
+
+func (x *rtt) afterLoad() {}
+func (x *rtt) load(m state.Map) {
+ m.Load("srtt", &x.srtt)
+ m.Load("rttvar", &x.rttvar)
+}
+
+func (x *fastRecovery) beforeSave() {}
+func (x *fastRecovery) save(m state.Map) {
+ x.beforeSave()
+ m.Save("active", &x.active)
+ m.Save("first", &x.first)
+ m.Save("last", &x.last)
+ m.Save("maxCwnd", &x.maxCwnd)
+ m.Save("highRxt", &x.highRxt)
+ m.Save("rescueRxt", &x.rescueRxt)
+}
+
+func (x *fastRecovery) afterLoad() {}
+func (x *fastRecovery) load(m state.Map) {
+ m.Load("active", &x.active)
+ m.Load("first", &x.first)
+ m.Load("last", &x.last)
+ m.Load("maxCwnd", &x.maxCwnd)
+ m.Load("highRxt", &x.highRxt)
+ m.Load("rescueRxt", &x.rescueRxt)
+}
+
+func (x *unixTime) beforeSave() {}
+func (x *unixTime) save(m state.Map) {
+ x.beforeSave()
+ m.Save("second", &x.second)
+ m.Save("nano", &x.nano)
+}
+
+func (x *unixTime) afterLoad() {}
+func (x *unixTime) load(m state.Map) {
+ m.Load("second", &x.second)
+ m.Load("nano", &x.nano)
+}
+
+func (x *segmentList) beforeSave() {}
+func (x *segmentList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *segmentList) afterLoad() {}
+func (x *segmentList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *segmentEntry) beforeSave() {}
+func (x *segmentEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *segmentEntry) afterLoad() {}
+func (x *segmentEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("tcp.SACKInfo", (*SACKInfo)(nil), state.Fns{Save: (*SACKInfo).save, Load: (*SACKInfo).load})
+ state.Register("tcp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("tcp.keepalive", (*keepalive)(nil), state.Fns{Save: (*keepalive).save, Load: (*keepalive).load})
+ state.Register("tcp.receiver", (*receiver)(nil), state.Fns{Save: (*receiver).save, Load: (*receiver).load})
+ state.Register("tcp.renoState", (*renoState)(nil), state.Fns{Save: (*renoState).save, Load: (*renoState).load})
+ state.Register("tcp.SACKScoreboard", (*SACKScoreboard)(nil), state.Fns{Save: (*SACKScoreboard).save, Load: (*SACKScoreboard).load})
+ state.Register("tcp.segment", (*segment)(nil), state.Fns{Save: (*segment).save, Load: (*segment).load})
+ state.Register("tcp.segmentQueue", (*segmentQueue)(nil), state.Fns{Save: (*segmentQueue).save, Load: (*segmentQueue).load})
+ state.Register("tcp.sender", (*sender)(nil), state.Fns{Save: (*sender).save, Load: (*sender).load})
+ state.Register("tcp.rtt", (*rtt)(nil), state.Fns{Save: (*rtt).save, Load: (*rtt).load})
+ state.Register("tcp.fastRecovery", (*fastRecovery)(nil), state.Fns{Save: (*fastRecovery).save, Load: (*fastRecovery).load})
+ state.Register("tcp.unixTime", (*unixTime)(nil), state.Fns{Save: (*unixTime).save, Load: (*unixTime).load})
+ state.Register("tcp.segmentList", (*segmentList)(nil), state.Fns{Save: (*segmentList).save, Load: (*segmentList).load})
+ state.Register("tcp.segmentEntry", (*segmentEntry)(nil), state.Fns{Save: (*segmentEntry).save, Load: (*segmentEntry).load})
+}
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
new file mode 100644
index 000000000..fc1c7cbd2
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -0,0 +1,141 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+)
+
+type timerState int
+
+const (
+ timerStateDisabled timerState = iota
+ timerStateEnabled
+ timerStateOrphaned
+)
+
+// timer is a timer implementation that reduces the interactions with the
+// runtime timer infrastructure by letting timers run (and potentially
+// eventually expire) even if they are stopped. It makes it cheaper to
+// disable/reenable timers at the expense of spurious wakes. This is useful for
+// cases when the same timer is disabled/reenabled repeatedly with relatively
+// long timeouts farther into the future.
+//
+// TCP retransmit timers benefit from this because they the timeouts are long
+// (currently at least 200ms), and get disabled when acks are received, and
+// reenabled when new pending segments are sent.
+//
+// It is advantageous to avoid interacting with the runtime because it acquires
+// a global mutex and performs O(log n) operations, where n is the global number
+// of timers, whenever a timer is enabled or disabled, and may make a syscall.
+//
+// This struct is thread-compatible.
+type timer struct {
+ // state is the current state of the timer, it can be one of the
+ // following values:
+ // disabled - the timer is disabled.
+ // orphaned - the timer is disabled, but the runtime timer is
+ // enabled, which means that it will evetually cause a
+ // spurious wake (unless it gets enabled again before
+ // then).
+ // enabled - the timer is enabled, but the runtime timer may be set
+ // to an earlier expiration time due to a previous
+ // orphaned state.
+ state timerState
+
+ // target is the expiration time of the current timer. It is only
+ // meaningful in the enabled state.
+ target time.Time
+
+ // runtimeTarget is the expiration time of the runtime timer. It is
+ // meaningful in the enabled and orphaned states.
+ runtimeTarget time.Time
+
+ // timer is the runtime timer used to wait on.
+ timer *time.Timer
+}
+
+// init initializes the timer. Once it expires, it the given waker will be
+// asserted.
+func (t *timer) init(w *sleep.Waker) {
+ t.state = timerStateDisabled
+
+ // Initialize a runtime timer that will assert the waker, then
+ // immediately stop it.
+ t.timer = time.AfterFunc(time.Hour, func() {
+ w.Assert()
+ })
+ t.timer.Stop()
+}
+
+// cleanup frees all resources associated with the timer.
+func (t *timer) cleanup() {
+ t.timer.Stop()
+}
+
+// checkExpiration checks if the given timer has actually expired, it should be
+// called whenever a sleeper wakes up due to the waker being asserted, and is
+// used to check if it's a supurious wake (due to a previously orphaned timer)
+// or a legitimate one.
+func (t *timer) checkExpiration() bool {
+ // Transition to fully disabled state if we're just consuming an
+ // orphaned timer.
+ if t.state == timerStateOrphaned {
+ t.state = timerStateDisabled
+ return false
+ }
+
+ // The timer is enabled, but it may have expired early. Check if that's
+ // the case, and if so, reset the runtime timer to the correct time.
+ now := time.Now()
+ if now.Before(t.target) {
+ t.runtimeTarget = t.target
+ t.timer.Reset(t.target.Sub(now))
+ return false
+ }
+
+ // The timer has actually expired, disable it for now and inform the
+ // caller.
+ t.state = timerStateDisabled
+ return true
+}
+
+// disable disables the timer, leaving it in an orphaned state if it wasn't
+// already disabled.
+func (t *timer) disable() {
+ if t.state != timerStateDisabled {
+ t.state = timerStateOrphaned
+ }
+}
+
+// enabled returns true if the timer is currently enabled, false otherwise.
+func (t *timer) enabled() bool {
+ return t.state == timerStateEnabled
+}
+
+// enable enables the timer, programming the runtime timer if necessary.
+func (t *timer) enable(d time.Duration) {
+ t.target = time.Now().Add(d)
+
+ // Check if we need to set the runtime timer.
+ if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) {
+ t.runtimeTarget = t.target
+ t.timer.Reset(d)
+ }
+
+ t.state = timerStateEnabled
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
new file mode 100644
index 000000000..3d52a4f31
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -0,0 +1,1002 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "math"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type udpPacket struct {
+ udpPacketEntry
+ senderAddress tcpip.FullAddress
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
+ // views is used as buffer for data when its length is large
+ // enough to store a VectorisedView.
+ views [8]buffer.View `state:"nosave"`
+}
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateConnected
+ stateClosed
+)
+
+// endpoint represents a UDP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized.
+//
+// It implements tcpip.Endpoint.
+//
+// +stateify savable
+type endpoint struct {
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue, and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList udpPacketList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ id stack.TransportEndpointID
+ state endpointState
+ bindNICID tcpip.NICID
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+ dstPort uint16
+ v6only bool
+ multicastTTL uint8
+ multicastAddr tcpip.Address
+ multicastNICID tcpip.NICID
+ multicastLoop bool
+ reusePort bool
+ broadcast bool
+
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
+ // multicastMemberships that need to be remvoed when the endpoint is
+ // closed. Protected by the mu mutex.
+ multicastMemberships []multicastMembership
+
+ // effectiveNetProtos contains the network protocols actually in use. In
+ // most cases it will only contain "netProto", but in cases like IPv6
+ // endpoints with v6only set to false, this could include multiple
+ // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+ // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+ // address).
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
+}
+
+// +stateify savable
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+ return &endpoint{
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ // RFC 1075 section 5.4 recommends a TTL of 1 for membership
+ // requests.
+ //
+ // RFC 5135 4.2.1 appears to assume that IGMP messages have a
+ // TTL of 1.
+ //
+ // RFC 5135 Appendix A defines TTL=1: A multicast source that
+ // wants its traffic to not traverse a router (e.g., leave a
+ // home network) may find it useful to send traffic with IP
+ // TTL=1.
+ //
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
+ multicastLoop: true,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
+
+ switch e.state {
+ case stateBound, stateConnected:
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ }
+
+ for _, mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
+ // Close the receive list and drain it.
+ e.rcvMu.Lock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
+ e.rcvMu.Unlock()
+
+ e.route.Release()
+
+ // Update the state.
+ e.state = stateClosed
+
+ e.mu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.rcvMu.Lock()
+
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = p.senderAddress
+ }
+
+ return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+ switch e.state {
+ case stateInitial:
+ case stateConnected:
+ return false, nil
+
+ case stateBound:
+ if to == nil {
+ return false, tcpip.ErrDestinationRequired
+ }
+ return false, nil
+ default:
+ return false, tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // The state changed when we released the shared locked and re-acquired
+ // it in exclusive mode. Try again.
+ if e.state != stateInitial {
+ return true, nil
+ }
+
+ // The state is still 'initial', so try to bind the endpoint.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
+
+// connectRoute establishes a route to the specified interface or the
+// configured multicast interface if no interface is specified and the
+// specified address is a multicast address.
+func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return stack.Route{}, 0, 0, err
+ }
+
+ localAddr := e.id.LocalAddress
+ if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
+ if nicid == 0 {
+ nicid = e.multicastNICID
+ }
+ if localAddr == "" {
+ localAddr = e.multicastAddr
+ }
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
+ if err != nil {
+ return stack.Route{}, 0, 0, err
+ }
+ return r, nicid, netProto, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ if p.Size() > math.MaxUint16 {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
+
+ to := opts.To
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // If we've shutdown with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Prepare for write.
+ for {
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !retry {
+ break
+ }
+ }
+
+ var route *stack.Route
+ var dstPort uint16
+ if to == nil {
+ route = &e.route
+ dstPort = e.dstPort
+
+ if route.IsResolutionRequired() {
+ // Promote lock to exclusive if using a shared route, given that it may need to
+ // change in Route.Resolve() call below.
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != stateConnected {
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+ }
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicid := to.NIC
+ if e.bindNICID != 0 {
+ if nicid != 0 && nicid != e.bindNICID {
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ nicid = e.bindNICID
+ }
+
+ if to.Addr == header.IPv4Broadcast && !e.broadcast {
+ return 0, nil, tcpip.ErrBroadcastDisabled
+ }
+
+ r, _, _, err := e.connectRoute(nicid, *to)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer r.Release()
+
+ route = &r
+ dstPort = to.Port
+ }
+
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ v, err := p.Get(p.Size())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ ttl := route.DefaultTTL()
+ if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
+ ttl = e.multicastTTL
+ }
+
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(len(v)), nil, nil
+}
+
+// Peek only returns data from a single datagram, so do nothing here.
+func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // We only allow this to be set when we're in the initial state.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v != 0
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
+ netProto, err := e.checkV4Mapped(&fa, false)
+ if err != nil {
+ return err
+ }
+ nic := v.NIC
+ addr := fa.Addr
+
+ if nic == 0 && addr == "" {
+ e.multicastAddr = ""
+ e.multicastNICID = 0
+ break
+ }
+
+ if nic != 0 {
+ if !e.stack.CheckNIC(nic) {
+ return tcpip.ErrBadLocalAddress
+ }
+ } else {
+ nic = e.stack.CheckLocalAddress(0, netProto, addr)
+ if nic == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ if e.bindNICID != 0 && e.bindNICID != nic {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.multicastNICID = nic
+ e.multicastAddr = addr
+
+ case tcpip.AddMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr == header.IPv4Any {
+ if nicID == 0 {
+ r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrUnknownDevice
+ }
+
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ for _, mem := range e.multicastMemberships {
+ if mem == memToInsert {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships = append(e.multicastMemberships, memToInsert)
+
+ case tcpip.RemoveMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr == header.IPv4Any {
+ if nicID == 0 {
+ r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrUnknownDevice
+ }
+
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+ memToRemoveIndex := -1
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ for i, mem := range e.multicastMemberships {
+ if mem == memToRemove {
+ memToRemoveIndex = i
+ break
+ }
+ }
+ if memToRemoveIndex == -1 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+
+ case tcpip.MulticastLoopOption:
+ e.mu.Lock()
+ e.multicastLoop = bool(v)
+ e.mu.Unlock()
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+
+ case tcpip.BroadcastOption:
+ e.mu.Lock()
+ e.broadcast = v != 0
+ e.mu.Unlock()
+
+ return nil
+ }
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ e.rcvMu.Lock()
+ if e.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := e.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastTTLOption(e.multicastTTL)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastInterfaceOption{
+ e.multicastNICID,
+ e.multicastAddr,
+ }
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.MulticastLoopOption:
+ e.mu.RLock()
+ v := e.multicastLoop
+ e.mu.RUnlock()
+
+ *o = tcpip.MulticastLoopOption(v)
+ return nil
+
+ case *tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ case *tcpip.BroadcastOption:
+ e.mu.RLock()
+ v := e.broadcast
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// sendUDP sends a UDP segment via the provided network endpoint and under the
+// provided identity.
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
+ // Allocate a buffer for the UDP header.
+ hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
+
+ // Initialize the header.
+ udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+
+ length := uint16(hdr.UsedLength() + data.Size())
+ udp.Encode(&header.UDPFields{
+ SrcPort: localPort,
+ DstPort: remotePort,
+ Length: length,
+ })
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ for _, v := range data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udp.SetChecksum(^udp.CalculateChecksum(xsum))
+ }
+
+ // Track count of packets sent.
+ r.Stats().UDP.PacketsSent.Increment()
+
+ return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl)
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ // Fail if using a v4 mapped address on a v6only endpoint.
+ if e.v6only {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == "\x00\x00\x00\x00" {
+ addr.Addr = ""
+ }
+
+ // Fail if we are bound to an IPv6 address.
+ if !allowMismatch && len(e.id.LocalAddress) == 16 {
+ return 0, tcpip.ErrNetworkUnreachable
+ }
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ if addr.Port == 0 {
+ // We don't support connecting to port zero.
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicid := addr.NIC
+ var localPort uint16
+ switch e.state {
+ case stateInitial:
+ case stateBound, stateConnected:
+ localPort = e.id.LocalPort
+ if e.bindNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.bindNICID {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nicid = e.bindNICID
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ r, nicid, netProto, err := e.connectRoute(nicid, addr)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ id := stack.TransportEndpointID{
+ LocalAddress: r.LocalAddress,
+ LocalPort: localPort,
+ RemotePort: addr.Port,
+ RemoteAddress: r.RemoteAddress,
+ }
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv4ProtocolNumber,
+ header.IPv6ProtocolNumber,
+ }
+ }
+
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ // Remove the old registration.
+ if e.id.LocalPort != 0 {
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ }
+
+ e.id = id
+ e.route = r.Clone()
+ e.dstPort = addr.Port
+ e.regNICID = nicid
+ e.effectiveNetProtos = netProtos
+
+ e.state = stateConnected
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // A socket in the bound state can still receive multicast messages,
+ // so we need to notify waiters on shutdown.
+ if e.state != stateBound && e.state != stateConnected {
+ return tcpip.ErrNotConnected
+ }
+
+ e.shutdownFlags |= flags
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.rcvMu.Lock()
+ wasClosed := e.rcvClosed
+ e.rcvClosed = true
+ e.rcvMu.Unlock()
+
+ if !wasClosed {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen is not supported by UDP, it just fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by UDP, it just fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if e.id.LocalPort == 0 {
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
+ if err != nil {
+ return id, err
+ }
+ id.LocalPort = port
+ }
+
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
+ if err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ }
+ return id, err
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, true)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
+ }
+
+ nicid := addr.NIC
+ if len(addr.Addr) != 0 {
+ // A local address was specified, verify that it's valid.
+ nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nicid == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.id = id
+ e.regNICID = nicid
+ e.effectiveNetProtos = netProtos
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional.
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ err := e.bindLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ // Save the effective NICID generated by bindLocked.
+ e.bindNICID = e.regNICID
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ }, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ // Get the header then trim it from the view.
+ hdr := header.UDP(vv.First())
+ if int(hdr.Length()) > vv.Size() {
+ // Malformed packet.
+ e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ vv.TrimFront(header.UDPMinimumSize)
+
+ e.rcvMu.Lock()
+ e.stack.Stats().UDP.PacketsReceived.Increment()
+
+ // Drop the packet if our buffer is currently full.
+ if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ pkt := &udpPacket{
+ senderAddress: tcpip.FullAddress{
+ NIC: r.NICID(),
+ Addr: id.RemoteAddress,
+ Port: hdr.SourcePort(),
+ },
+ }
+ pkt.data = vv.Clone(pkt.views[:])
+ e.rcvList.PushBack(pkt)
+ e.rcvBufSize += vv.Size()
+
+ pkt.timestamp = e.stack.NowNanoseconds()
+
+ e.rcvMu.Unlock()
+
+ // Notify any waiters that there's data to be read now.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
new file mode 100644
index 000000000..74e8e9fd5
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -0,0 +1,112 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves udpPacket.data field.
+func (u *udpPacket) saveData() buffer.VectorisedView {
+ // We cannot save u.data directly as u.data.views may alias to u.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return u.data.Clone(nil)
+}
+
+// loadData loads udpPacket.data field.
+func (u *udpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing u.views for data.views.
+ u.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after savercvBufSizeMax(), which would have
+ // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ e.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ e.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+
+ for _, m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(err)
+ }
+ }
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ netProto := e.effectiveNetProtos[0]
+ // Connect() and bindLocked() both assert
+ //
+ // netProto == header.IPv6ProtocolNumber
+ //
+ // before creating a multi-entry effectiveNetProtos.
+ if len(e.effectiveNetProtos) > 1 {
+ netProto = header.IPv6ProtocolNumber
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
+ if err != nil {
+ panic(*err)
+ }
+
+ e.id.LocalAddress = e.route.LocalAddress
+ } else if len(e.id.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ // Our saved state had a port, but we don't actually have a
+ // reservation. We need to remove the port from our state, but still
+ // pass it to the reservation machinery.
+ id := e.id
+ e.id.LocalPort = 0
+ e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ if err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
new file mode 100644
index 000000000..25bdd2929
--- /dev/null
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -0,0 +1,96 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Forwarder is a session request forwarder, which allows clients to decide
+// what to do with a session request, for example: ignore it, or process it.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+ handler func(*ForwarderRequest)
+
+ stack *stack.Stack
+}
+
+// NewForwarder allocates and initializes a new forwarder.
+func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
+ return &Forwarder{
+ stack: s,
+ handler: handler,
+ }
+}
+
+// HandlePacket handles all packets.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ f.handler(&ForwarderRequest{
+ stack: f.stack,
+ route: r,
+ id: id,
+ vv: vv,
+ })
+
+ return true
+}
+
+// ForwarderRequest represents a session request received by the forwarder and
+// passed to the client. Clients may optionally create an endpoint to represent
+// it via CreateEndpoint.
+type ForwarderRequest struct {
+ stack *stack.Stack
+ route *stack.Route
+ id stack.TransportEndpointID
+ vv buffer.VectorisedView
+}
+
+// ID returns the 4-tuple (src address, src port, dst address, dst port) that
+// represents the session request.
+func (r *ForwarderRequest) ID() stack.TransportEndpointID {
+ return r.id
+}
+
+// CreateEndpoint creates a connected UDP endpoint for the session request.
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := newEndpoint(r.stack, r.route.NetProto, queue)
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ ep.id = r.id
+ ep.route = r.route.Clone()
+ ep.dstPort = r.id.RemotePort
+ ep.regNICID = r.route.NICID()
+
+ ep.state = stateConnected
+
+ ep.rcvMu.Lock()
+ ep.rcvReady = true
+ ep.rcvMu.Unlock()
+
+ ep.HandlePacket(r.route, r.id, r.vv)
+
+ return ep, nil
+}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
new file mode 100644
index 000000000..3d31dfbf1
--- /dev/null
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package udp contains the implementation of the UDP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing udp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolName is the string representation of the udp protocol name.
+ ProtocolName = "udp"
+
+ // ProtocolNumber is the udp protocol number.
+ ProtocolNumber = header.UDPProtocolNumber
+)
+
+type protocol struct{}
+
+// Number returns the udp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+ return ProtocolNumber
+}
+
+// NewEndpoint creates a new udp endpoint.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// NewRawEndpoint creates a new raw UDP endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid udp packet size.
+func (*protocol) MinimumPacketSize() int {
+ return header.UDPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given udp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ h := header.UDP(v)
+ return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+ return true
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go
new file mode 100755
index 000000000..673a9373b
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_packet_list.go
@@ -0,0 +1,173 @@
+package udp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type udpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type udpPacketList struct {
+ head *udpPacket
+ tail *udpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *udpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *udpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *udpPacketList) Front() *udpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *udpPacketList) Back() *udpPacket {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *udpPacketList) PushFront(e *udpPacket) {
+ udpPacketElementMapper{}.linkerFor(e).SetNext(l.head)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *udpPacketList) PushBack(e *udpPacket) {
+ udpPacketElementMapper{}.linkerFor(e).SetNext(nil)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *udpPacketList) PushBackList(m *udpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *udpPacketList) InsertAfter(b, e *udpPacket) {
+ a := udpPacketElementMapper{}.linkerFor(b).Next()
+ udpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *udpPacketList) InsertBefore(a, e *udpPacket) {
+ b := udpPacketElementMapper{}.linkerFor(a).Prev()
+ udpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *udpPacketList) Remove(e *udpPacket) {
+ prev := udpPacketElementMapper{}.linkerFor(e).Prev()
+ next := udpPacketElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ udpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ udpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type udpPacketEntry struct {
+ next *udpPacket
+ prev *udpPacket
+}
+
+// Next returns the entry that follows e in the list.
+func (e *udpPacketEntry) Next() *udpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *udpPacketEntry) Prev() *udpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *udpPacketEntry) SetNext(elem *udpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *udpPacketEntry) SetPrev(elem *udpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
new file mode 100755
index 000000000..711e2feeb
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -0,0 +1,128 @@
+// automatically generated by stateify.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *udpPacket) beforeSave() {}
+func (x *udpPacket) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("udpPacketEntry", &x.udpPacketEntry)
+ m.Save("senderAddress", &x.senderAddress)
+ m.Save("timestamp", &x.timestamp)
+}
+
+func (x *udpPacket) afterLoad() {}
+func (x *udpPacket) load(m state.Map) {
+ m.Load("udpPacketEntry", &x.udpPacketEntry)
+ m.Load("senderAddress", &x.senderAddress)
+ m.Load("timestamp", &x.timestamp)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("netProto", &x.netProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvReady", &x.rcvReady)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("id", &x.id)
+ m.Save("state", &x.state)
+ m.Save("bindNICID", &x.bindNICID)
+ m.Save("regNICID", &x.regNICID)
+ m.Save("dstPort", &x.dstPort)
+ m.Save("v6only", &x.v6only)
+ m.Save("multicastTTL", &x.multicastTTL)
+ m.Save("multicastAddr", &x.multicastAddr)
+ m.Save("multicastNICID", &x.multicastNICID)
+ m.Save("multicastLoop", &x.multicastLoop)
+ m.Save("reusePort", &x.reusePort)
+ m.Save("broadcast", &x.broadcast)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("multicastMemberships", &x.multicastMemberships)
+ m.Save("effectiveNetProtos", &x.effectiveNetProtos)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("rcvReady", &x.rcvReady)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("id", &x.id)
+ m.Load("state", &x.state)
+ m.Load("bindNICID", &x.bindNICID)
+ m.Load("regNICID", &x.regNICID)
+ m.Load("dstPort", &x.dstPort)
+ m.Load("v6only", &x.v6only)
+ m.Load("multicastTTL", &x.multicastTTL)
+ m.Load("multicastAddr", &x.multicastAddr)
+ m.Load("multicastNICID", &x.multicastNICID)
+ m.Load("multicastLoop", &x.multicastLoop)
+ m.Load("reusePort", &x.reusePort)
+ m.Load("broadcast", &x.broadcast)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("multicastMemberships", &x.multicastMemberships)
+ m.Load("effectiveNetProtos", &x.effectiveNetProtos)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *multicastMembership) beforeSave() {}
+func (x *multicastMembership) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nicID", &x.nicID)
+ m.Save("multicastAddr", &x.multicastAddr)
+}
+
+func (x *multicastMembership) afterLoad() {}
+func (x *multicastMembership) load(m state.Map) {
+ m.Load("nicID", &x.nicID)
+ m.Load("multicastAddr", &x.multicastAddr)
+}
+
+func (x *udpPacketList) beforeSave() {}
+func (x *udpPacketList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *udpPacketList) afterLoad() {}
+func (x *udpPacketList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *udpPacketEntry) beforeSave() {}
+func (x *udpPacketEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *udpPacketEntry) afterLoad() {}
+func (x *udpPacketEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("udp.udpPacket", (*udpPacket)(nil), state.Fns{Save: (*udpPacket).save, Load: (*udpPacket).load})
+ state.Register("udp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("udp.multicastMembership", (*multicastMembership)(nil), state.Fns{Save: (*multicastMembership).save, Load: (*multicastMembership).load})
+ state.Register("udp.udpPacketList", (*udpPacketList)(nil), state.Fns{Save: (*udpPacketList).save, Load: (*udpPacketList).load})
+ state.Register("udp.udpPacketEntry", (*udpPacketEntry)(nil), state.Fns{Save: (*udpPacketEntry).save, Load: (*udpPacketEntry).load})
+}