summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/header/icmpv4.go1
-rw-r--r--pkg/tcpip/header/icmpv6.go11
-rw-r--r--pkg/tcpip/link/channel/channel.go4
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go36
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go8
-rw-r--r--pkg/tcpip/link/loopback/loopback.go3
-rw-r--r--pkg/tcpip/link/muxed/injectable.go4
-rw-r--r--pkg/tcpip/link/nested/nested.go15
-rw-r--r--pkg/tcpip/link/nested/nested_test.go4
-rw-r--r--pkg/tcpip/link/packetsocket/BUILD14
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go50
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go12
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go21
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go4
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go5
-rw-r--r--pkg/tcpip/link/tun/device.go43
-rw-r--r--pkg/tcpip/link/waitable/waitable.go14
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go9
-rw-r--r--pkg/tcpip/network/arp/arp.go7
-rw-r--r--pkg/tcpip/network/arp/arp_test.go58
-rw-r--r--pkg/tcpip/network/ip_test.go5
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go3
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go10
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go52
-rw-r--r--pkg/tcpip/stack/conntrack.go136
-rw-r--r--pkg/tcpip/stack/forwarder_test.go23
-rw-r--r--pkg/tcpip/stack/iptables.go173
-rw-r--r--pkg/tcpip/stack/iptables_targets.go2
-rw-r--r--pkg/tcpip/stack/iptables_types.go22
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go2
-rw-r--r--pkg/tcpip/stack/ndp.go140
-rw-r--r--pkg/tcpip/stack/ndp_test.go28
-rw-r--r--pkg/tcpip/stack/nic.go30
-rw-r--r--pkg/tcpip/stack/nic_test.go7
-rw-r--r--pkg/tcpip/stack/packet_buffer.go4
-rw-r--r--pkg/tcpip/stack/registration.go28
-rw-r--r--pkg/tcpip/stack/stack.go12
-rw-r--r--pkg/tcpip/tcpip.go52
-rw-r--r--pkg/tcpip/time_unsafe.go30
-rw-r--r--pkg/tcpip/timer.go147
-rw-r--r--pkg/tcpip/timer_test.go91
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go54
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go4
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/connect.go6
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go26
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go4
-rw-r--r--pkg/tcpip/transport/tcp/segment.go6
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go23
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go8
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
54 files changed, 1046 insertions, 414 deletions
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 7908c5744..1a631b31a 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -72,6 +72,7 @@ const (
// Values for ICMP code as defined in RFC 792.
const (
ICMPv4TTLExceeded = 0
+ ICMPv4HostUnreachable = 1
ICMPv4PortUnreachable = 3
ICMPv4FragmentationNeeded = 4
)
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index c7ee2de57..a13b4b809 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -110,9 +110,16 @@ const (
ICMPv6RedirectMsg ICMPv6Type = 137
)
-// Values for ICMP code as defined in RFC 4443.
+// Values for ICMP destination unreachable code as defined in RFC 4443 section
+// 3.1.
const (
- ICMPv6PortUnreachable = 4
+ ICMPv6NetworkUnreachable = 0
+ ICMPv6Prohibited = 1
+ ICMPv6BeyondScope = 2
+ ICMPv6AddressUnreachable = 3
+ ICMPv6PortUnreachable = 4
+ ICMPv6Policy = 5
+ ICMPv6RejectRoute = 6
)
// Type is the ICMP type field.
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index a2bb773d4..e12a5929b 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -302,3 +302,7 @@ func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 6aa1badc7..c18bb91fb 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -386,26 +386,33 @@ const (
_VIRTIO_NET_HDR_GSO_TCPV6 = 4
)
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if e.hdrSize > 0 {
// Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
// Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
}
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if e.hdrSize > 0 {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
var builder iovec.Builder
@@ -448,22 +455,8 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
// Send a batch of packets through batchFD.
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
- var eth header.Ethernet
if e.hdrSize > 0 {
- // Add ethernet header if needed.
- eth = make(header.Ethernet, header.EthernetMinimumSize)
- ethHdr := &header.EthernetFields{
- DstAddr: pkt.EgressRoute.RemoteLinkAddress,
- Type: pkt.NetworkProtocolNumber,
- }
-
- // Preserve the src address if it's set in the route.
- if pkt.EgressRoute.LocalLinkAddress != "" {
- ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress
- } else {
- ethHdr.SrcAddr = e.addr
- }
- eth.Encode(ethHdr)
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
var vnetHdrBuf []byte
@@ -493,7 +486,6 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
var builder iovec.Builder
builder.Add(vnetHdrBuf)
- builder.Add(eth)
builder.Add(pkt.Header.View())
for _, v := range pkt.Data.Views() {
builder.Add(v)
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 4bad930c7..7b995b85a 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -107,6 +107,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin
c.ch <- packetInfo{remote, protocol, pkt}
}
+func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNoEthernetProperties(t *testing.T) {
c := newContext(t, &Options{MTU: mtu})
defer c.cleanup()
@@ -510,6 +514,10 @@ func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAdd
d.pkts = append(d.pkts, pkt)
}
+func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestDispatchPacketFormat(t *testing.T) {
for _, test := range []struct {
name string
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 3b17d8c28..781cdd317 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -118,3 +118,6 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareLoopback
}
+
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index c305d9e86..56a611825 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -135,6 +135,10 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unsupported operation")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 328bd048e..d40de54df 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -61,6 +61,16 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
}
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverOutboundPacket(remote, local, protocol, pkt)
+ }
+}
+
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.mu.Lock()
@@ -135,3 +145,8 @@ func (e *Endpoint) GSOMaxSize() uint32 {
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.child.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.child.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
index c1a219f02..7d9249c1c 100644
--- a/pkg/tcpip/link/nested/nested_test.go
+++ b/pkg/tcpip/link/nested/nested_test.go
@@ -55,6 +55,10 @@ func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAd
d.count++
}
+func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNestedLinkEndpoint(t *testing.T) {
const emptyAddress = tcpip.LinkAddress("")
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
new file mode 100644
index 000000000..6fff160ce
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "packetsocket",
+ srcs = ["endpoint.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
new file mode 100644
index 000000000..3922c2a04
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packetsocket provides a link layer endpoint that provides the ability
+// to loop outbound packets to any AF_PACKET sockets that may be interested in
+// the outgoing packet.
+package packetsocket
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ nested.Endpoint
+}
+
+// New creates a new packetsocket LinkEndpoint.
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ e := &endpoint{}
+ e.Endpoint.Init(lower, e)
+ return e
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index c84fe1bb9..467083239 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -107,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
@@ -194,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3267/): Queue these packets as well once
+ // WriteRawPacket takes PacketBuffer instead of VectorisedView.
return e.lower.WriteRawPacket(vv)
}
@@ -213,3 +220,8 @@ func (e *endpoint) Wait() {
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return e.lower.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index a36862c67..507c76b76 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- // Add the ethernet header here.
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
v := pkt.Data.ToView()
// Transmit the packet.
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 28a2e88ba..8f3cd9449 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L
c.packetCh <- struct{}{}
}
+func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (c *testContext) cleanup() {
c.ep.Close()
closeFDs(&c.txCfg)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index d9cd4e83a..509076643 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -123,6 +123,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 47446efec..04ae58e59 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -272,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
if info.Pkt.LinkHeader == nil {
- hdr := &header.EthernetFields{
- SrcAddr: info.Route.LocalLinkAddress,
- DstAddr: info.Route.RemoteLinkAddress,
- Type: info.Proto,
- }
- if hdr.SrcAddr == "" {
- hdr.SrcAddr = d.endpoint.LinkAddress()
- }
-
- eth := make(header.Ethernet, header.EthernetMinimumSize)
- eth.Encode(hdr)
- vv.AppendView(buffer.View(eth))
- } else {
- vv.AppendView(info.Pkt.LinkHeader)
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
}
+ vv.AppendView(info.Pkt.LinkHeader)
}
// Append upper headers.
@@ -366,3 +354,30 @@ func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
}
return header.ARPHardwareNone
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.isTap {
+ return
+ }
+ eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+ pkt.LinkHeader = buffer.View(eth)
+ hdr := &header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: protocol,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = e.LinkAddress()
+ }
+
+ eth.Encode(hdr)
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header.
+func (e *tunEndpoint) MaxHeaderLength() uint16 {
+ if e.isTap {
+ return header.EthernetMinimumSize
+ }
+ return 0
+}
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 24a8dc2eb..b152a0f26 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -60,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatchGate.Leave()
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
// registers with the lower endpoint as its dispatcher so that "e" is called
// for inbound packets.
@@ -153,3 +162,8 @@ func (e *Endpoint) Wait() {}
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.lower.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index ffb2354be..c448a888f 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -40,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress,
e.dispatchCount++
}
+func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.attachCount++
e.dispatcher = dispatcher
@@ -90,6 +94,11 @@ func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index b0f57040c..31a242482 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -160,9 +160,12 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
- RemoteLinkAddress: header.EthernetBroadcastAddress,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 66e67429c..a35a64a0f 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -32,10 +32,14 @@ import (
)
const (
- stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
- stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
- stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+ stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+ stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d")
+ stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
+ stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
+ stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
type testContext struct {
@@ -50,8 +54,7 @@ func newTestContext(t *testing.T) *testContext {
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
})
- const defaultMTU = 65536
- ep := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
@@ -119,7 +122,7 @@ func TestDirectRequest(t *testing.T) {
if !rep.IsValid() {
t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
}
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
}
if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
@@ -144,3 +147,44 @@ func TestDirectRequest(t *testing.T) {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: stackLinkAddr2,
+ expectLinkAddr: stackLinkAddr2,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: header.EthernetBroadcastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ p := arp.NewProtocol()
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
+ if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index a5b780ca2..615bae648 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -185,6 +185,11 @@ func (*testObject) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("not implemented")
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 1b67aa066..83e71cb8c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -129,6 +129,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
+ case header.ICMPv4HostUnreachable:
+ e.handleControl(stack.ControlNoRoute, 0, pkt)
+
case header.ICMPv4PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 3b79749b5..24600d877 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -128,6 +128,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch header.ICMPv6(hdr).Code() {
+ case header.ICMPv6NetworkUnreachable:
+ e.handleControl(stack.ControlNetworkUnreachable, 0, pkt)
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
@@ -502,7 +504,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
snaddr := header.SolicitedNodeAddr(addr)
// TODO(b/148672031): Use stack.FindRoute instead of manually creating the
@@ -511,8 +513,12 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
r := &stack.Route{
LocalAddress: localAddr,
RemoteAddress: snaddr,
- RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr),
+ RemoteLinkAddress: remoteLinkAddr,
}
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr)
+ }
+
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 52a01b44e..f86aaed1d 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -34,6 +34,9 @@ const (
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
var (
@@ -257,8 +260,7 @@ func newTestContext(t *testing.T) *testContext {
}),
}
- const defaultMTU = 65536
- c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+ c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
@@ -271,7 +273,7 @@ func newTestContext(t *testing.T) *testContext {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -951,3 +953,47 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
})
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ snaddr := header.SolicitedNodeAddr(lladdr0)
+ mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
+
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: linkAddr1,
+ expectLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: mcaddr,
+ },
+ }
+
+ for _, test := range tests {
+ p := NewProtocol()
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
+ if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index d39baf620..559a1c4dd 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -49,7 +49,8 @@ const (
type manipType int
const (
- manipDstPrerouting manipType = iota
+ manipNone manipType = iota
+ manipDstPrerouting
manipDstOutput
)
@@ -113,13 +114,11 @@ type conn struct {
// update the state of tcb. It is immutable.
tcbHook Hook
- // mu protects tcb.
+ // mu protects all mutable state.
mu sync.Mutex `state:"nosave"`
-
// tcb is TCB control block. It is used to keep track of states
// of tcp connection and is protected by mu.
tcb tcpconntrack.TCB
-
// lastUsed is the last time the connection saw a relevant packet, and
// is updated by each packet on the connection. It is protected by mu.
lastUsed time.Time `state:".(unixTime)"`
@@ -141,8 +140,26 @@ func (cn *conn) timedOut(now time.Time) bool {
return now.Sub(cn.lastUsed) > defaultTimeout
}
+// update the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
+}
+
// ConnTrack tracks all connections created for NAT rules. Most users are
-// expected to only call handlePacket and createConnFor.
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
//
// ConnTrack keeps all connections in a slice of buckets, each of which holds a
// linked list of tuples. This gives us some desirable properties:
@@ -248,8 +265,7 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
return nil, dirOriginal
}
-// createConnFor creates a new conn for pkt.
-func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -272,10 +288,15 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
manip = manipDstOutput
}
conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
+}
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
// Lock the buckets in the correct order.
- tupleBucket := ct.bucket(tid)
- replyBucket := ct.bucket(replyTID)
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
ct.mu.RLock()
defer ct.mu.RUnlock()
if tupleBucket < replyBucket {
@@ -289,22 +310,37 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
ct.buckets[tupleBucket].mu.Lock()
}
- // Add the tuple to the map.
- ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
- ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another thread.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
+
+ if !alreadyInserted {
+ // Add the tuple to the map.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ }
// Unlocking can happen in any order.
ct.buckets[tupleBucket].mu.Unlock()
if tupleBucket != replyBucket {
ct.buckets[replyBucket].mu.Unlock()
}
-
- return conn
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -322,12 +358,22 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
netHeader.SetSourceAddress(conn.original.dstAddr)
}
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
// handlePacketOutput manipulates ports for packets in Output hook.
func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -362,20 +408,31 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
}
// handlePacket will manipulate the port and address of the packet if the
-// connection exists.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+// connection exists. Returns whether, after the packet traverses the tables,
+// it should create a new entry in the table.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
if pkt.NatDone {
- return
+ return false
}
if hook != Prerouting && hook != Output {
- return
+ return false
+ }
+
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return false
}
conn, dir := ct.connFor(pkt)
+ // Connection or Rule not found for the packet.
if conn == nil {
- // Connection not found for the packet or the packet is invalid.
- return
+ return true
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if tcpHeader == nil {
+ return false
}
switch hook {
@@ -395,14 +452,39 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
// Mark the connection as having been used recently so it isn't reaped.
conn.lastUsed = time.Now()
// Update connection state.
- if tcpHeader := header.TCP(pkt.TransportHeader); conn.tcb.IsEmpty() {
- conn.tcb.Init(tcpHeader)
- conn.tcbHook = hook
- } else if hook == conn.tcbHook {
- conn.tcb.UpdateStateOutbound(tcpHeader)
- } else {
- conn.tcb.UpdateStateInbound(tcpHeader)
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should be called after traversing iptables rules only, to ensure that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
+ }
+
+ // We only track TCP connections.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return
+ }
+
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return
}
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+ ct.insertConn(conn)
}
// bucket gets the conntrack bucket for a tupleID.
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index eefb4b07f..c962693f5 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -121,10 +121,12 @@ func (*fwdTestNetworkEndpoint) Close() {}
type fwdTestNetworkProtocol struct {
addrCache *linkAddrCache
addrResolveDelay time.Duration
- onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address)
+ onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
}
+var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
+
func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
@@ -174,10 +176,10 @@ func (f *fwdTestNetworkProtocol) Close() {}
func (f *fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
if f.addrCache != nil && f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
- f.onLinkAddressResolved(f.addrCache, addr)
+ f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr)
})
}
return nil
@@ -307,6 +309,11 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
@@ -400,7 +407,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any address will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -458,7 +465,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Only packets to address 3 will be resolved to the
// link address "c".
if addr == "\x03" {
@@ -510,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -554,7 +561,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -611,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index bbf3b60e8..cbbae4224 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -22,22 +22,30 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// Table names.
+// tableID is an index into IPTables.tables.
+type tableID int
+
const (
- TablenameNat = "nat"
- TablenameMangle = "mangle"
- TablenameFilter = "filter"
+ natID tableID = iota
+ mangleID
+ filterID
+ numTables
)
-// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+// Table names.
const (
- ChainNamePrerouting = "PREROUTING"
- ChainNameInput = "INPUT"
- ChainNameForward = "FORWARD"
- ChainNameOutput = "OUTPUT"
- ChainNamePostrouting = "POSTROUTING"
+ NATTable = "nat"
+ MangleTable = "mangle"
+ FilterTable = "filter"
)
+// nameToID is immutable.
+var nameToID = map[string]tableID{
+ NATTable: natID,
+ MangleTable: mangleID,
+ FilterTable: filterID,
+}
+
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
@@ -48,11 +56,9 @@ const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
func DefaultTables() *IPTables {
- // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
- // iotas.
return &IPTables{
- tables: map[string]Table{
- TablenameNat: Table{
+ tables: [numTables]Table{
+ natID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
@@ -60,60 +66,66 @@ func DefaultTables() *IPTables {
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- Underflows: map[Hook]int{
+ Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- UserChains: map[string]int{},
},
- TablenameMangle: Table{
+ mangleID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
- Underflows: map[Hook]int{
- Prerouting: 0,
- Output: 1,
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
- TablenameFilter: Table{
+ filterID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
},
- priorities: map[Hook][]string{
- Input: []string{TablenameNat, TablenameFilter},
- Prerouting: []string{TablenameMangle, TablenameNat},
- Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
+ priorities: [NumHooks][]tableID{
+ Prerouting: []tableID{mangleID, natID},
+ Input: []tableID{natID, filterID},
+ Output: []tableID{mangleID, natID, filterID},
},
connections: ConnTrack{
seed: generateRandUint32(),
@@ -127,51 +139,48 @@ func DefaultTables() *IPTables {
func EmptyFilterTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// EmptyNatTable returns a Table with no rules and the filter table chains
+// EmptyNATTable returns a Table with no rules and the filter table chains
// mapped to HookUnset.
-func EmptyNatTable() Table {
+func EmptyNATTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Forward: HookUnset,
},
- Underflows: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ Underflows: [NumHooks]int{
+ Forward: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// GetTable returns table by name.
+// GetTable returns a table by name.
func (it *IPTables) GetTable(name string) (Table, bool) {
+ id, ok := nameToID[name]
+ if !ok {
+ return Table{}, false
+ }
it.mu.RLock()
defer it.mu.RUnlock()
- t, ok := it.tables[name]
- return t, ok
+ return it.tables[id], true
}
// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table) {
+func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
+ id, ok := nameToID[name]
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
@@ -181,14 +190,8 @@ func (it *IPTables) ReplaceTable(name string, table Table) {
it.startReaper(reaperDelay)
}
it.modified = true
- it.tables[name] = table
-}
-
-// GetPriorities returns slice of priorities associated with hook.
-func (it *IPTables) GetPriorities(hook Hook) []string {
- it.mu.RLock()
- defer it.mu.RUnlock()
- return it.priorities[hook]
+ it.tables[id] = table
+ return nil
}
// A chainVerdict is what a table decides should be done with a packet.
@@ -223,11 +226,19 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
// Packets are manipulated only if connection and matching
// NAT rule exists.
- it.connections.handlePacket(pkt, hook, gso, r)
+ shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- for _, tablename := range it.GetPriorities(hook) {
- table, _ := it.GetTable(tablename)
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ priorities := it.priorities[hook]
+ for _, tableID := range priorities {
+ // If handlePacket already NATed the packet, we don't need to
+ // check the NAT table.
+ if tableID == natID && pkt.NatDone {
+ continue
+ }
+ table := it.tables[tableID]
ruleIdx := table.BuiltinChains[hook]
switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
// If the table returns Accept, move on to the next table.
@@ -256,6 +267,20 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
}
}
+ // If this connection should be tracked, try to add an entry for it. If
+ // traversing the nat table didn't end in adding an entry,
+ // maybeInsertNoop will add a no-op entry for the connection. This is
+ // needeed when establishing connections so that the SYN/ACK reply to an
+ // outgoing SYN is delivered to the correct endpoint rather than being
+ // redirected by a prerouting rule.
+ //
+ // From the iptables documentation: "If there is no rule, a `null'
+ // binding is created: this usually does not map the packet, but exists
+ // to ensure we don't map another stream over an existing one."
+ if shouldTrack {
+ it.connections.maybeInsertNoop(pkt, hook)
+ }
+
// Every table returned Accept.
return true
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index d43f60c67..dc88033c7 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -153,7 +153,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
// Set up conection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
- if conn := ct.createConnFor(pkt, hook, rt); conn != nil {
+ if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index eb70e3104..73274ada9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -84,14 +84,14 @@ type IPTables struct {
// mu protects tables, priorities, and modified.
mu sync.RWMutex
- // tables maps table names to tables. User tables have arbitrary names.
- // mu needs to be locked for accessing.
- tables map[string]Table
+ // tables maps tableIDs to tables. Holds builtin tables only, not user
+ // tables. mu must be locked for accessing.
+ tables [numTables]Table
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. mu needs to be locked for accessing.
- priorities map[Hook][]string
+ priorities [NumHooks][]tableID
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
@@ -113,22 +113,20 @@ type Table struct {
Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
- BuiltinChains map[Hook]int
+ BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules
// (i.e. the rule to execute if the chain returns without a verdict).
- Underflows map[Hook]int
-
- // UserChains holds user-defined chains for the keyed by name. Users
- // can give their chains arbitrary names.
- UserChains map[string]int
+ Underflows [NumHooks]int
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
- for hook := range table.BuiltinChains {
- hooks |= 1 << hook
+ for hook, ruleIdx := range table.BuiltinChains {
+ if ruleIdx != HookUnset {
+ hooks |= 1 << hook
+ }
}
return hooks
}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 403557fd7..6f73a0ce4 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 1baa498d0..b15b8d1cb 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -48,7 +48,7 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
if f := r.onLinkAddressRequest; f != nil {
f()
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index e28c23d66..9dce11a97 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -469,7 +469,7 @@ type ndpState struct {
rtrSolicit struct {
// The timer used to send the next router solicitation message.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the Router Solicitation timer know that it has been stopped.
//
@@ -503,7 +503,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the DAD timer know that it has been stopped.
//
@@ -515,38 +515,38 @@ type dadState struct {
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- // Timer to invalidate the default router.
+ // Job to invalidate the default router.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations was configured to do so.
type onLinkPrefixState struct {
- // Timer to invalidate the on-link prefix.
+ // Job to invalidate the on-link prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// tempSLAACAddrState holds state associated with a temporary SLAAC address.
type tempSLAACAddrState struct {
- // Timer to deprecate the temporary SLAAC address.
+ // Job to deprecate the temporary SLAAC address.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the temporary SLAAC address.
+ // Job to invalidate the temporary SLAAC address.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
- // Timer to regenerate the temporary SLAAC address.
+ // Job to regenerate the temporary SLAAC address.
//
// Must not be nil.
- regenTimer *tcpip.CancellableTimer
+ regenJob *tcpip.Job
createdAt time.Time
@@ -561,15 +561,15 @@ type tempSLAACAddrState struct {
// slaacPrefixState holds state associated with a SLAAC prefix.
type slaacPrefixState struct {
- // Timer to deprecate the prefix.
+ // Job to deprecate the prefix.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the prefix.
+ // Job to invalidate the prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
// Nonzero only when the address is not valid forever.
validUntil time.Time
@@ -651,12 +651,12 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
var done bool
- var timer *time.Timer
+ var timer tcpip.Timer
// We initially start a timer to fire immediately because some of the DAD work
// cannot be done while holding the NIC's lock. This is effectively the same
// as starting a goroutine but we use a timer that fires immediately so we can
// reset it for the next DAD iteration.
- timer = time.AfterFunc(0, func() {
+ timer = ndp.nic.stack.Clock().AfterFunc(0, func() {
ndp.nic.mu.Lock()
defer ndp.nic.mu.Unlock()
@@ -871,9 +871,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
- // the invalidation timer.
- rtr.invalidationTimer.StopLocked()
- rtr.invalidationTimer.Reset(rl)
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
@@ -950,7 +950,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.StopLocked()
+ rtr.invalidationJob.Cancel()
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
@@ -979,12 +979,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
state := defaultRouterState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
- state.invalidationTimer.Reset(rl)
+ state.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = state
}
@@ -1009,13 +1009,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
state := onLinkPrefixState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
if l < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(l)
+ state.invalidationJob.Schedule(l)
}
ndp.onLinkPrefixes[prefix] = state
@@ -1033,7 +1033,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- s.invalidationTimer.StopLocked()
+ s.invalidationJob.Cancel()
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
@@ -1082,14 +1082,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
//
- // Update the invalidation timer.
+ // Update the invalidation job.
- prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationJob.Cancel()
if vl < header.NDPInfiniteLifetime {
- // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
// the new valid lifetime.
- prefixState.invalidationTimer.Reset(vl)
+ prefixState.invalidationJob.Schedule(vl)
}
ndp.onLinkPrefixes[prefix] = prefixState
@@ -1154,7 +1154,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
@@ -1162,7 +1162,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.deprecateSLAACAddress(state.stableAddr.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1184,19 +1184,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if !ndp.generateSLAACAddr(prefix, &state) {
// We were unable to generate an address for the prefix, we do not nothing
- // further as there is no reason to maintain state or timers for a prefix we
+ // further as there is no reason to maintain state or jobs for a prefix we
// do not have an address for.
return
}
- // Setup the initial timers to deprecate and invalidate prefix.
+ // Setup the initial jobs to deprecate and invalidate prefix.
if pl < header.NDPInfiniteLifetime && pl != 0 {
- state.deprecationTimer.Reset(pl)
+ state.deprecationJob.Schedule(pl)
}
if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
+ state.invalidationJob.Schedule(vl)
state.validUntil = now.Add(vl)
}
@@ -1428,7 +1428,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
state := tempSLAACAddrState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1441,7 +1441,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.deprecateSLAACAddress(tempAddrState.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1454,7 +1454,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1481,9 +1481,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ref: ref,
}
- state.deprecationTimer.Reset(pl)
- state.invalidationTimer.Reset(vl)
- state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration)
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
prefixState.generationAttempts++
prefixState.tempAddrs[generatedAddr.Address] = state
@@ -1518,16 +1518,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
prefixState.stableAddr.ref.deprecated = false
}
- // If prefix was preferred for some finite lifetime before, stop the
- // deprecation timer so it can be reset.
- prefixState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
now := time.Now()
- // Reset the deprecation timer if prefix has a finite preferred lifetime.
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
if pl < header.NDPInfiniteLifetime {
if !deprecated {
- prefixState.deprecationTimer.Reset(pl)
+ prefixState.deprecationJob.Schedule(pl)
}
prefixState.preferredUntil = now.Add(pl)
} else {
@@ -1546,9 +1546,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
if vl >= header.NDPInfiniteLifetime {
- // Handle the infinite valid lifetime separately as we do not keep a timer
- // in this case.
- prefixState.invalidationTimer.StopLocked()
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
prefixState.validUntil = time.Time{}
} else {
var effectiveVl time.Duration
@@ -1569,8 +1569,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
if effectiveVl != 0 {
- prefixState.invalidationTimer.StopLocked()
- prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
prefixState.validUntil = now.Add(effectiveVl)
}
}
@@ -1582,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// Note, we do not need to update the entries in the temporary address map
- // after updating the timers because the timers are held as pointers.
+ // after updating the jobs because the jobs are held as pointers.
var regenForAddr tcpip.Address
allAddressesRegenerated := true
for tempAddr, tempAddrState := range prefixState.tempAddrs {
@@ -1596,14 +1596,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
- // reset the invalidation timer.
+ // reset the invalidation job.
newValidLifetime := validUntil.Sub(now)
if newValidLifetime <= 0 {
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
continue
}
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.invalidationTimer.Reset(newValidLifetime)
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
// As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
// address is the lower of the preferred lifetime of the stable address or
@@ -1616,17 +1616,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer preferred, deprecate it immediately.
- // Otherwise, reset the deprecation timer.
+ // Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
- tempAddrState.deprecationTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
if newPreferredLifetime <= 0 {
ndp.deprecateSLAACAddress(tempAddrState.ref)
} else {
tempAddrState.ref.deprecated = false
- tempAddrState.deprecationTimer.Reset(newPreferredLifetime)
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
}
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.regenJob.Cancel()
if tempAddrState.regenerated {
} else {
allAddressesRegenerated = false
@@ -1637,7 +1637,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// immediately after we finish iterating over the temporary addresses.
regenForAddr = tempAddr
} else {
- tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
}
}
}
@@ -1717,7 +1717,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
ndp.cleanupSLAACPrefixResources(prefix, state)
}
-// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry.
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
//
// Panics if the SLAAC prefix is not known.
//
@@ -1729,8 +1729,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
}
state.stableAddr.ref = nil
- state.deprecationTimer.StopLocked()
- state.invalidationTimer.StopLocked()
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
delete(ndp.slaacPrefixes, prefix)
}
@@ -1775,13 +1775,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi
}
// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
-// timers and entry.
+// jobs and entry.
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
- tempAddrState.deprecationTimer.StopLocked()
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
delete(tempAddrs, tempAddr)
}
@@ -1860,7 +1860,7 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = time.AfterFunc(delay, func() {
+ ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() {
ndp.nic.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 6f86abc98..644ba7c33 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -1254,7 +1254,7 @@ func TestRouterDiscovery(t *testing.T) {
default:
}
- // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
@@ -1271,7 +1271,7 @@ func TestRouterDiscovery(t *testing.T) {
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
expectRouterEvent(llAddr2, false)
- // Wait for lladdr3's router invalidation timer to fire. The lifetime
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
@@ -1502,7 +1502,7 @@ func TestPrefixDiscovery(t *testing.T) {
default:
}
- // Wait for prefix2's most recent invalidation timer plus some buffer to
+ // Wait for prefix2's most recent invalidation job plus some buffer to
// expire.
select {
case e := <-ndpDisp.prefixC:
@@ -2395,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
for _, addr := range tempAddrs {
// Wait for a deprecation then invalidation event, or just an invalidation
// event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation timers could fire in any
+ // cases because the deprecation and invalidation jobs could execute in any
// order.
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2432,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
}
-// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's
-// regeneration timer gets updated when refreshing the address's lifetimes.
-func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
+// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's
+// regeneration job gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
nicID = 1
regenAfter = 2 * time.Second
@@ -2533,7 +2533,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
//
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
- // - this regeneration does not depend on a timer.
+ // - this regeneration does not depend on a job.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(tempAddr2, newAddr)
@@ -2559,11 +2559,11 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
}
// Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration timer gets reset.
+ // RA, the regeneration job gets scheduled again.
//
// The maximum lifetime is the sum of the minimum lifetimes for temporary
// addresses + the time that has already passed since the last address was
- // generated so that the regeneration timer is needed to generate the next
+ // generated so that the regeneration job is needed to generate the next
// address.
newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
@@ -2993,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
expectPrimaryAddr(addr2)
}
-// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
// when its preferred lifetime expires.
-func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
@@ -3513,8 +3513,8 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
expectAutoGenAddrEvent(addr, invalidatedAddr)
- // Wait for the original valid lifetime to make sure the original timer
- // got stopped/cleaned up.
+ // Wait for the original valid lifetime to make sure the original job got
+ // cancelled/cleaned up.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 7b80534e6..fea0ce7e8 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1200,15 +1200,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// Are any packet sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Check whether there are packet sockets listening for every protocol.
- // If we received a packet with protocol EthernetProtocolAll, then the
- // previous for loop will have handled it.
- if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
- }
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketHost
+ ep.HandlePacket(n.id, local, protocol, p)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
@@ -1311,6 +1309,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
+// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
+func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ // We do not deliver to protocol specific packet endpoints as on Linux
+ // only ETH_P_ALL endpoints get outbound packets.
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketOutgoing
+ // Add the link layer header as outgoing packets are intercepted
+ // before the link layer header is created.
+ n.linkEP.AddHeader(local, remote, protocol, p)
+ ep.HandlePacket(n.id, local, protocol, p)
+ }
+}
+
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
// TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 3bc9fd831..a70792b50 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -89,6 +89,11 @@ func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
// An IPv6 NetworkEndpoint that throws away outgoing packets.
@@ -238,7 +243,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index e3556d5d2..5d6865e35 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -79,6 +79,10 @@ type PacketBuffer struct {
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
NatDone bool
+
+ // PktType indicates the SockAddrLink.PacketType of the packet as defined in
+ // https://www.man7.org/linux/man-pages/man7/packet.7.html.
+ PktType tcpip.PacketType
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index f260eeb7f..8604c4259 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -52,8 +52,11 @@ type TransportEndpointID struct {
type ControlType int
// The following are the allowed values for ControlType values.
+// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlPacketTooBig ControlType = iota
+ ControlNetworkUnreachable ControlType = iota
+ ControlNoRoute
+ ControlPacketTooBig
ControlPortUnreachable
ControlUnknown
)
@@ -330,8 +333,7 @@ type NetworkProtocol interface {
}
// NetworkDispatcher contains the methods used by the network stack to deliver
-// packets to the appropriate network endpoint after it has been handled by
-// the data link layer.
+// inbound/outbound packets to the appropriate network/packet(if any) endpoints.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing.
@@ -342,6 +344,16 @@ type NetworkDispatcher interface {
//
// DeliverNetworkPacket takes ownership of pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverOutboundPacket is called by link layer when a packet is being
+ // sent out.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverOutboundPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverOutboundPacket takes ownership of pkt.
+ DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -443,6 +455,9 @@ type LinkEndpoint interface {
// See:
// https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
ARPHardwareType() header.ARPHardwareType
+
+ // AddHeader adds a link layer header to pkt if required.
+ AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -463,12 +478,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr.
- // The request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
+ // the request on the local network if remoteLinkAddr is the zero value. The
+ // request is sent on linkEP with localAddr as the source.
//
// A valid response will cause the discovery protocol's network
// endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+ LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 2b7ece851..a6faa22c2 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -728,6 +728,11 @@ func New(opts Options) *Stack {
return s
}
+// newJob returns a tcpip.Job using the Stack clock.
+func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
+
// UniqueID returns a unique identifier.
func (s *Stack) UniqueID() uint64 {
return s.uniqueIDGenerator.UniqueID()
@@ -801,9 +806,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f
}
}
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (s *Stack) NowNanoseconds() int64 {
- return s.clock.NowNanoseconds()
+// Clock returns the Stack's clock for retrieving the current time and
+// scheduling work.
+func (s *Stack) Clock() tcpip.Clock {
+ return s.clock
}
// Stats returns a mutable copy of the current stats.
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 48ad56d4d..21aafb0a2 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -192,7 +192,7 @@ func (e ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
-// A Clock provides the current time.
+// A Clock provides the current time and schedules work for execution.
//
// Times returned by a Clock should always be used for application-visible
// time. Only monotonic times should be used for netstack internal timekeeping.
@@ -203,6 +203,31 @@ type Clock interface {
// NowMonotonic returns a monotonic time value.
NowMonotonic() int64
+
+ // AfterFunc waits for the duration to elapse and then calls f in its own
+ // goroutine. It returns a Timer that can be used to cancel the call using
+ // its Stop method.
+ AfterFunc(d time.Duration, f func()) Timer
+}
+
+// Timer represents a single event. A Timer must be created with
+// Clock.AfterFunc.
+type Timer interface {
+ // Stop prevents the Timer from firing. It returns true if the call stops the
+ // timer, false if the timer has already expired or been stopped.
+ //
+ // If Stop returns false, then the timer has already expired and the function
+ // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop
+ // does not wait for f to complete before returning. If the caller needs to
+ // know whether f is completed, it must coordinate with f explicitly.
+ Stop() bool
+
+ // Reset changes the timer to expire after duration d.
+ //
+ // Reset should be invoked only on stopped or expired timers. If the timer is
+ // known to have expired, Reset can be used directly. Otherwise, the caller
+ // must coordinate with the function f of Clock.AfterFunc(d, f).
+ Reset(d time.Duration)
}
// Address is a byte slice cast as a string that represents the address of a
@@ -316,6 +341,28 @@ const (
ShutdownWrite
)
+// PacketType is used to indicate the destination of the packet.
+type PacketType uint8
+
+const (
+ // PacketHost indicates a packet addressed to the local host.
+ PacketHost PacketType = iota
+
+ // PacketOtherHost indicates an outgoing packet addressed to
+ // another host caught by a NIC in promiscuous mode.
+ PacketOtherHost
+
+ // PacketOutgoing for a packet originating from the local host
+ // that is looped back to a packet socket.
+ PacketOutgoing
+
+ // PacketBroadcast indicates a link layer broadcast packet.
+ PacketBroadcast
+
+ // PacketMulticast indicates a link layer multicast packet.
+ PacketMulticast
+)
+
// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods.
//
@@ -555,6 +602,9 @@ type Endpoint interface {
type LinkPacketInfo struct {
// Protocol is the NetworkProtocolNumber for the packet.
Protocol NetworkProtocolNumber
+
+ // PktType is used to indicate the destination of the packet.
+ PktType PacketType
}
// PacketEndpoint are additional methods that are only implemented by Packet
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 7f172f978..f32d58091 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -20,7 +20,7 @@
package tcpip
import (
- _ "time" // Used with go:linkname.
+ "time" // Used with go:linkname.
_ "unsafe" // Required for go:linkname.
)
@@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 {
_, _, mono := now()
return mono
}
+
+// AfterFunc implements Clock.AfterFunc.
+func (*StdClock) AfterFunc(d time.Duration, f func()) Timer {
+ return &stdTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+type stdTimer struct {
+ t *time.Timer
+}
+
+var _ Timer = (*stdTimer)(nil)
+
+// Stop implements Timer.Stop.
+func (st *stdTimer) Stop() bool {
+ return st.t.Stop()
+}
+
+// Reset implements Timer.Reset.
+func (st *stdTimer) Reset(d time.Duration) {
+ st.t.Reset(d)
+}
+
+// NewStdTimer returns a Timer implemented with the time package.
+func NewStdTimer(t *time.Timer) Timer {
+ return &stdTimer{t: t}
+}
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
index 5554c573f..f1dd7c310 100644
--- a/pkg/tcpip/timer.go
+++ b/pkg/tcpip/timer.go
@@ -20,50 +20,49 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
-// cancellableTimerInstance is a specific instance of CancellableTimer.
+// jobInstance is a specific instance of Job.
//
-// Different instances are created each time CancellableTimer is Reset so each
-// timer has its own earlyReturn signal. This is to address a bug when a
-// CancellableTimer is stopped and reset in quick succession resulting in a
-// timer instance's earlyReturn signal being affected or seen by another timer
-// instance.
+// Different instances are created each time Job is scheduled so each timer has
+// its own earlyReturn signal. This is to address a bug when a Job is stopped
+// and reset in quick succession resulting in a timer instance's earlyReturn
+// signal being affected or seen by another timer instance.
//
// Consider the following sceneario where timer instances share a common
// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a
// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
// (B), third (C), and fourth (D) instance of the timer firing, respectively):
// T1: Obtain L
-// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T1: Create a new Job w/ lock L (create instance A)
// T2: instance A fires, blocked trying to obtain L.
// T1: Attempt to stop instance A (set earlyReturn = true)
-// T1: Reset timer (create instance B)
+// T1: Schedule timer (create instance B)
// T3: instance B fires, blocked trying to obtain L.
// T1: Attempt to stop instance B (set earlyReturn = true)
-// T1: Reset timer (create instance C)
+// T1: Schedule timer (create instance C)
// T4: instance C fires, blocked trying to obtain L.
// T1: Attempt to stop instance C (set earlyReturn = true)
-// T1: Reset timer (create instance D)
+// T1: Schedule timer (create instance D)
// T5: instance D fires, blocked trying to obtain L.
// T1: Release L
//
-// Now that T1 has released L, any of the 4 timer instances can take L and check
-// earlyReturn. If the timers simply check earlyReturn and then do nothing
-// further, then instance D will never early return even though it was not
-// requested to stop. If the timers reset earlyReturn before early returning,
-// then all but one of the timers will do work when only one was expected to.
-// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// Now that T1 has released L, any of the 4 timer instances can take L and
+// check earlyReturn. If the timers simply check earlyReturn and then do
+// nothing further, then instance D will never early return even though it was
+// not requested to stop. If the timers reset earlyReturn before early
+// returning, then all but one of the timers will do work when only one was
+// expected to. If Job resets earlyReturn when resetting, then all the timers
// will fire (again, when only one was expected to).
//
// To address the above concerns the simplest solution was to give each timer
// its own earlyReturn signal.
-type cancellableTimerInstance struct {
- timer *time.Timer
+type jobInstance struct {
+ timer Timer
// Used to inform the timer to early return when it gets stopped while the
// lock the timer tries to obtain when fired is held (T1 is a goroutine that
// tries to cancel the timer and T2 is the goroutine that handles the timer
// firing):
- // T1: Obtain the lock, then call StopLocked()
+ // T1: Obtain the lock, then call Cancel()
// T2: timer fires, and gets blocked on obtaining the lock
// T1: Releases lock
// T2: Obtains lock does unintended work
@@ -74,29 +73,33 @@ type cancellableTimerInstance struct {
earlyReturn *bool
}
-// stop stops the timer instance t from firing if it hasn't fired already. If it
+// stop stops the job instance j from firing if it hasn't fired already. If it
// has fired and is blocked at obtaining the lock, earlyReturn will be set to
// true so that it will early return when it obtains the lock.
-func (t *cancellableTimerInstance) stop() {
- if t.timer != nil {
- t.timer.Stop()
- *t.earlyReturn = true
+func (j *jobInstance) stop() {
+ if j.timer != nil {
+ j.timer.Stop()
+ *j.earlyReturn = true
}
}
-// CancellableTimer is a timer that does some work and can be safely cancelled
-// when it fires at the same time some "related work" is being done.
+// Job represents some work that can be scheduled for execution. The work can
+// be safely cancelled when it fires at the same time some "related work" is
+// being done.
//
// The term "related work" is defined as some work that needs to be done while
// holding some lock that the timer must also hold while doing some work.
//
-// Note, it is not safe to copy a CancellableTimer as its timer instance creates
-// a closure over the address of the CancellableTimer.
-type CancellableTimer struct {
+// Note, it is not safe to copy a Job as its timer instance creates
+// a closure over the address of the Job.
+type Job struct {
_ sync.NoCopy
+ // The clock used to schedule the backing timer
+ clock Clock
+
// The active instance of a cancellable timer.
- instance cancellableTimerInstance
+ instance jobInstance
// locker is the lock taken by the timer immediately after it fires and must
// be held when attempting to stop the timer.
@@ -113,59 +116,91 @@ type CancellableTimer struct {
fn func()
}
-// StopLocked prevents the Timer from firing if it has not fired already.
+// Cancel prevents the Job from executing if it has not executed already.
//
-// If the timer is blocked on obtaining the t.locker lock when StopLocked is
-// called, it will early return instead of calling t.fn.
+// Cancel requires appropriate locking to be in place for any resources managed
+// by the Job. If the Job is blocked on obtaining the lock when Cancel is
+// called, it will early return.
//
// Note, t will be modified.
//
-// t.locker MUST be locked.
-func (t *CancellableTimer) StopLocked() {
- t.instance.stop()
+// j.locker MUST be locked.
+func (j *Job) Cancel() {
+ j.instance.stop()
// Nothing to do with the stopped instance anymore.
- t.instance = cancellableTimerInstance{}
+ j.instance = jobInstance{}
}
-// Reset changes the timer to expire after duration d.
+// Schedule schedules the Job for execution after duration d. This can be
+// called on cancelled or completed Jobs to schedule them again.
//
-// Note, t will be modified.
+// Schedule should be invoked only on unscheduled, cancelled, or completed
+// Jobs. To be safe, callers should always call Cancel before calling Schedule.
//
-// Reset should only be called on stopped or expired timers. To be safe, callers
-// should always call StopLocked before calling Reset.
-func (t *CancellableTimer) Reset(d time.Duration) {
+// Note, j will be modified.
+func (j *Job) Schedule(d time.Duration) {
// Create a new instance.
earlyReturn := false
// Capture the locker so that updating the timer does not cause a data race
// when a timer fires and tries to obtain the lock (read the timer's locker).
- locker := t.locker
- t.instance = cancellableTimerInstance{
- timer: time.AfterFunc(d, func() {
+ locker := j.locker
+ j.instance = jobInstance{
+ timer: j.clock.AfterFunc(d, func() {
locker.Lock()
defer locker.Unlock()
if earlyReturn {
// If we reach this point, it means that the timer fired while another
- // goroutine called StopLocked while it had the lock. Simply return
- // here and do nothing further.
+ // goroutine called Cancel while it had the lock. Simply return here
+ // and do nothing further.
earlyReturn = false
return
}
- t.fn()
+ j.fn()
}),
earlyReturn: &earlyReturn,
}
}
-// NewCancellableTimer returns an unscheduled CancellableTimer with the given
-// locker and fn.
-//
-// fn MUST NOT attempt to lock locker.
-//
-// Callers must call Reset to schedule the timer to fire.
-func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer {
- return &CancellableTimer{locker: locker, fn: fn}
+// NewJob returns a new Job that can be used to schedule f to run in its own
+// gorountine. l will be locked before calling f then unlocked after f returns.
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// message = "bar"
+// mu.Unlock()
+//
+// // Output: bar
+//
+// f MUST NOT attempt to lock l.
+//
+// l MUST be locked prior to calling the returned job's Cancel().
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// job.Cancel()
+// mu.Unlock()
+func NewJob(c Clock, l sync.Locker, f func()) *Job {
+ return &Job{
+ clock: c,
+ locker: l,
+ fn: f,
+ }
}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index b4940e397..a82384c49 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -28,8 +28,8 @@ const (
longDuration = 1 * time.Second
)
-func TestCancellableTimerReassignment(t *testing.T) {
- var timer tcpip.CancellableTimer
+func TestJobReschedule(t *testing.T) {
+ var clock tcpip.StdClock
var wg sync.WaitGroup
var lock sync.Mutex
@@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) {
// that has an active timer (even if it has been stopped as a stopped
// timer may be blocked on a lock before it can check if it has been
// stopped while another goroutine holds the same lock).
- timer = *tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
wg.Done()
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
lock.Unlock()
}()
}
wg.Wait()
}
-func TestCancellableTimerFire(t *testing.T) {
+func TestJobExecution(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
ch <- struct{}{}
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) {
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(middleDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(middleDuration)
lock.Lock()
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
}
}
-func TestCancellableTimerResetFromShortDuration(t *testing.T) {
+func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
// Wait for timer to fire if it wasn't correctly stopped.
@@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
case <-time.After(middleDuration):
}
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
}
}
-func TestCancellableTimerImmediatelyStop(t *testing.T) {
+func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
for i := 0; i < 1000; i++ {
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
}
@@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) {
}
}
-func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
+func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
for i := 0; i < 10; i++ {
- timer.Reset(middleDuration)
+ job.Schedule(middleDuration)
lock.Lock()
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration * 2)
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
}
@@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration)
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
@@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
}
}
-func TestManyCancellableTimerResetUnderLock(t *testing.T) {
+func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 678f4e016..4612be4e7 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -797,7 +797,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 7b2083a09..0e46e6355 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -441,6 +441,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
Addr: tcpip.Address(hdr.SourceAddress()),
}
packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
@@ -448,34 +449,57 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
Addr: tcpip.Address(localAddr),
}
packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = pkt.Data
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip Link Header from the Header.
+ pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):])
+ combinedVV := pkt.Header.View().ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
var linkHeader buffer.View
- if len(pkt.LinkHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ var combinedVV buffer.VectorisedView
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if len(pkt.LinkHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
- } else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ combinedVV = linkHeader.ToVectorisedView()
+ }
+ if pkt.PktType == tcpip.PacketOutgoing {
+ // For outgoing packets the Link, Network and Transport
+ // headers are in the pkt.Header fields normally unless
+ // a Raw socket is in use in which case pkt.Header could
+ // be nil.
+ combinedVV.AppendView(pkt.Header.View())
}
- combinedVV := linkHeader.ToVectorisedView()
combinedVV.Append(pkt.Data)
packet.data = combinedVV
}
- packet.timestampNS = ep.stack.NowNanoseconds()
+ packet.timestampNS = ep.stack.Clock().NowNanoseconds()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index c2e9fd29f..f85a68554 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -456,7 +456,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
defer e.mu.Unlock()
// If a local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -700,7 +700,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
}
combinedVV.Append(pkt.Data)
packet.data = combinedVV
- packet.timestampNS = e.stack.NowNanoseconds()
+ packet.timestampNS = e.stack.Clock().NowNanoseconds()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 18ff89ffc..e860ee484 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -49,6 +49,7 @@ go_library(
"segment_heap.go",
"segment_queue.go",
"segment_state.go",
+ "segment_unsafe.go",
"snd.go",
"snd_state.go",
"tcp_endpoint_list.go",
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 81b740115..1798510bc 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
}
// Wait for notification.
@@ -616,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 83dc10ed0..0f7487963 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1209,6 +1209,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+ err := e.lastError
+ e.lastError = nil
+ return err
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1956,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- e.lastErrorMu.Lock()
- err := e.lastError
- e.lastError = nil
- e.lastErrorMu.Unlock()
- return err
+ return e.takeLastError()
case *tcpip.BindToDeviceOption:
e.LockUser()
@@ -2546,6 +2550,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.sndBufMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
+
+ case stack.ControlNoRoute:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNoRoute
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
+
+ case stack.ControlNetworkUnreachable:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNetworkUnreachable
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
}
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index b34e47bbd..5d6174a59 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -49,7 +49,7 @@ const (
// DefaultReceiveBufferSize is the default size of the receive buffer
// for an endpoint.
- DefaultReceiveBufferSize = 1 << 20 // 1MB
+ DefaultReceiveBufferSize = 32 << 10 // 32KB
// MaxBufferSize is the largest size a receive/send buffer can grow to.
MaxBufferSize = 4 << 20 // 4MB
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index dd89a292a..5e0bfe585 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -372,7 +372,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
// We only store the segment if it's within our buffer
// size limit.
if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += s.logicalLen()
+ r.pendingBufUsed += seqnum.Size(s.segMemSize())
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -406,7 +406,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= s.logicalLen()
+ r.pendingBufUsed -= seqnum.Size(s.segMemSize())
s.decRef()
}
return false, nil
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 0280892a8..bb60dc29d 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -138,6 +138,12 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// segMemSize is the amount of memory used to hold the segment data and
+// the associated metadata.
+func (s *segment) segMemSize() int {
+ return segSize + s.data.Size()
+}
+
// parse populates the sequence & ack numbers, flags, and window fields of the
// segment from the TCP header stored in the data. It then updates the view to
// skip the header.
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
new file mode 100644
index 000000000..0ab7b8f56
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "unsafe"
+)
+
+const (
+ segSize = int(unsafe.Sizeof(segment{}))
+)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 06fde2a79..37e7767d6 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -143,12 +143,14 @@ func New(t *testing.T, mtu uint32) *Context {
TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
})
+ const sendBufferSize = 1 << 20 // 1 MiB
+ const recvBufferSize = 1 << 20 // 1 MiB
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
@@ -202,7 +204,7 @@ func New(t *testing.T, mtu uint32) *Context {
t: t,
s: s,
linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
+ WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
}
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index a14643ae8..6e692da07 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1451,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()