summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD1
-rw-r--r--pkg/tcpip/checker/checker.go12
-rw-r--r--pkg/tcpip/header/ipv4.go37
-rw-r--r--pkg/tcpip/header/ipv6.go9
-rw-r--r--pkg/tcpip/header/parse/parse.go3
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go16
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go5
-rw-r--r--pkg/tcpip/link/loopback/loopback.go17
-rw-r--r--pkg/tcpip/link/muxed/BUILD1
-rw-r--r--pkg/tcpip/link/muxed/injectable.go8
-rw-r--r--pkg/tcpip/link/nested/BUILD1
-rw-r--r--pkg/tcpip/link/nested/nested.go6
-rw-r--r--pkg/tcpip/link/pipe/pipe.go5
-rw-r--r--pkg/tcpip/link/qdisc/fifo/BUILD1
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go8
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go15
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go64
-rw-r--r--pkg/tcpip/link/tun/device.go38
-rw-r--r--pkg/tcpip/link/waitable/BUILD2
-rw-r--r--pkg/tcpip/link/waitable/waitable.go12
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go6
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go34
-rw-r--r--pkg/tcpip/network/arp/arp_test.go7
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD1
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go71
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go207
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go50
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go23
-rw-r--r--pkg/tcpip/network/ip_test.go313
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go52
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go174
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go280
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go45
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go85
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go148
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go158
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go38
-rw-r--r--pkg/tcpip/network/testutil/testutil.go15
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go4
-rw-r--r--pkg/tcpip/socketops.go63
-rw-r--r--pkg/tcpip/stack/forwarding_test.go56
-rw-r--r--pkg/tcpip/stack/nic.go155
-rw-r--r--pkg/tcpip/stack/registration.go17
-rw-r--r--pkg/tcpip/stack/route.go22
-rw-r--r--pkg/tcpip/stack/stack.go48
-rw-r--r--pkg/tcpip/stack/stack_test.go156
-rw-r--r--pkg/tcpip/stack/transport_test.go95
-rw-r--r--pkg/tcpip/tcpip.go23
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go13
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go7
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go4
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go8
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go7
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go9
-rw-r--r--pkg/tcpip/transport/tcp/accept.go143
-rw-r--r--pkg/tcpip/transport/tcp/connect.go126
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go143
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go11
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/rack.go11
-rw-r--r--pkg/tcpip/transport/tcp/snd.go64
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go415
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go58
-rw-r--r--pkg/tcpip/transport/tcp/timer.go4
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go53
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go66
69 files changed, 2473 insertions, 1281 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 454e07662..27f96a3ac 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -5,6 +5,7 @@ package(licenses = ["notice"])
go_library(
name = "tcpip",
srcs = [
+ "socketops.go",
"tcpip.go",
"time_unsafe.go",
"timer.go",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 8868cf4e3..81f762e10 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -904,6 +904,12 @@ func ICMPv4Payload(want []byte) TransportChecker {
t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
payload := icmpv4.Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(payload) == 0 {
+ return
+ }
+
if diff := cmp.Diff(want, payload); diff != "" {
t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
@@ -994,6 +1000,12 @@ func ICMPv6Payload(want []byte) TransportChecker {
t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
payload := icmpv6.Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(payload) == 0 {
+ return
+ }
+
if diff := cmp.Diff(want, payload); diff != "" {
t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 7e32b31b4..91fe7b6a5 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -89,7 +89,17 @@ type IPv4Fields struct {
// DstAddr is the "destination ip address" of an IPv4 packet.
DstAddr tcpip.Address
- // Options is between 0 and 40 bytes or nil if empty.
+ // Options must be 40 bytes or less as they must fit along with the
+ // rest of the IPv4 header into the maximum size describable in the
+ // IHL field. RFC 791 section 3.1 says:
+ // IHL: 4 bits
+ //
+ // Internet Header Length is the length of the internet header in 32
+ // bit words, and thus points to the beginning of the data. Note that
+ // the minimum value for a correct header is 5.
+ //
+ // That leaves ten 32 bit (4 byte) fields for options. An attempt to encode
+ // more will fail.
Options IPv4Options
}
@@ -275,22 +285,19 @@ func (b IPv4) DestinationAddress() tcpip.Address {
// IPv4Options is a buffer that holds all the raw IP options.
type IPv4Options []byte
-// AllocationSize implements stack.NetOptions.
+// SizeWithPadding implements stack.NetOptions.
// It reports the size to allocate for the Options. RFC 791 page 23 (end of
// section 3.1) says of the padding at the end of the options:
// The internet header padding is used to ensure that the internet
// header ends on a 32 bit boundary.
-func (o IPv4Options) AllocationSize() int {
+func (o IPv4Options) SizeWithPadding() int {
return (len(o) + IPv4IHLStride - 1) & ^(IPv4IHLStride - 1)
}
-// Options returns a buffer holding the options or nil.
+// Options returns a buffer holding the options.
func (b IPv4) Options() IPv4Options {
hdrLen := b.HeaderLength()
- if hdrLen > IPv4MinimumSize {
- return IPv4Options(b[options:hdrLen:hdrLen])
- }
- return nil
+ return IPv4Options(b[options:hdrLen:hdrLen])
}
// TransportProtocol implements Network.TransportProtocol.
@@ -368,14 +375,20 @@ func (b IPv4) Encode(i *IPv4Fields) {
// worth a bit of optimisation here to keep the copy out of the fast path.
hdrLen := IPv4MinimumSize
if len(i.Options) != 0 {
- // AllocationSize is always >= len(i.Options).
- aLen := i.Options.AllocationSize()
+ // SizeWithPadding is always >= len(i.Options).
+ aLen := i.Options.SizeWithPadding()
hdrLen += aLen
if hdrLen > len(b) {
panic(fmt.Sprintf("encode received %d bytes, wanted >= %d", len(b), hdrLen))
}
- if aLen != copy(b[options:], i.Options) {
- _ = copy(b[options+len(i.Options):options+aLen], []byte{0, 0, 0, 0})
+ opts := b[options:]
+ // This avoids bounds checks on the next line(s) which would happen even
+ // if there's no work to do.
+ if n := copy(opts, i.Options); n != aLen {
+ padding := opts[n:][:aLen-n]
+ for i := range padding {
+ padding[i] = 0
+ }
}
}
b.SetHeaderLength(uint8(hdrLen))
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 4e7e5f76a..55d09355a 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -54,7 +54,7 @@ type IPv6Fields struct {
// NextHeader is the "next header" field of an IPv6 packet.
NextHeader uint8
- // HopLimit is the "hop limit" field of an IPv6 packet.
+ // HopLimit is the "Hop Limit" field of an IPv6 packet.
HopLimit uint8
// SrcAddr is the "source ip address" of an IPv6 packet.
@@ -171,7 +171,7 @@ func (b IPv6) PayloadLength() uint16 {
return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
}
-// HopLimit returns the value of the "hop limit" field of the ipv6 header.
+// HopLimit returns the value of the "Hop Limit" field of the ipv6 header.
func (b IPv6) HopLimit() uint8 {
return b[hopLimit]
}
@@ -236,6 +236,11 @@ func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
copy(b[v6DstAddr:][:IPv6AddressSize], addr)
}
+// SetHopLimit sets the value of the "Hop Limit" field.
+func (b IPv6) SetHopLimit(v uint8) {
+ b[hopLimit] = v
+}
+
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
func (b IPv6) SetNextHeader(v uint8) {
b[IPv6NextHeaderOffset] = v
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
index 5ca75c834..2042f214a 100644
--- a/pkg/tcpip/header/parse/parse.go
+++ b/pkg/tcpip/header/parse/parse.go
@@ -109,6 +109,9 @@ traverseExtensions:
fragOffset = extHdr.FragmentOffset()
fragMore = extHdr.More()
}
+ rawPayload := it.AsRawHeader(true /* consume */)
+ extensionsSize = dataClone.Size() - rawPayload.Buf.Size()
+ break traverseExtensions
case header.IPv6RawPayloadHeader:
// We've found the payload after any extensions.
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 39ca774ef..973f06cbc 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -9,7 +9,6 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index c95aef63c..a7f5f4979 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -271,21 +270,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := PacketInfo{
- Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }),
- Proto: 0,
- GSO: nil,
- }
-
- e.q.Write(p)
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 975309fc8..fc620c7d5 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -558,11 +558,6 @@ func viewsEqual(vs1, vs2 []buffer.View) bool {
return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0])
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- return rawfile.NonBlockingWrite(e.fds[0], vv.ToView())
-}
-
// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound.
func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
return rawfile.NonBlockingWrite(e.fds[0], packet)
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 38aa694e4..edca57e4e 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -96,23 +96,6 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- // There should be an ethernet header at the beginning of vv.
- hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
- if !ok {
- // Reject the packet if it's shorter than an ethernet header.
- return tcpip.ErrBadAddress
- }
- linkHeader := header.Ethernet(hdr)
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt)
-
- return nil
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareLoopback
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index e7493e5c5..cbda59775 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 56a611825..22e79ce3a 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -17,7 +17,6 @@ package muxed
import (
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -106,13 +105,6 @@ func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protoco
return tcpip.ErrNoRoute
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- // WriteRawPacket doesn't get a route or network address, so there's
- // nowhere to write this.
- return tcpip.ErrNoRoute
-}
-
// InjectOutbound writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the dest address.
func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
index 2cdb23475..00b42b924 100644
--- a/pkg/tcpip/link/nested/BUILD
+++ b/pkg/tcpip/link/nested/BUILD
@@ -11,7 +11,6 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index d40de54df..0ee54c3d5 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -19,7 +19,6 @@ package nested
import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -123,11 +122,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return e.child.WritePackets(r, gso, pkts, protocol)
}
-// WriteRawPacket implements stack.LinkEndpoint.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- return e.child.WriteRawPacket(vv)
-}
-
// Wait implements stack.LinkEndpoint.
func (e *Endpoint) Wait() {
e.child.Wait()
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 523b0d24b..71fcb73e1 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -67,11 +67,6 @@ func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList,
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.
-func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- panic("not implemented")
-}
-
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
index 1d0079bd6..5bea598eb 100644
--- a/pkg/tcpip/link/qdisc/fifo/BUILD
+++ b/pkg/tcpip/link/qdisc/fifo/BUILD
@@ -13,7 +13,6 @@ go_library(
"//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index fc1e34fc7..9b41d60d5 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -197,13 +196,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
return enqueued, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- // TODO(gvisor.dev/issue/3267): Queue these packets as well once
- // WriteRawPacket takes PacketBuffer instead of VectorisedView.
- return e.lower.WriteRawPacket(vv)
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (e *endpoint) Wait() {
e.lower.Wait()
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 7fb8a6c49..a1e7018c8 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -224,21 +224,6 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- views := vv.Views()
- // Transmit the packet.
- e.mu.Lock()
- ok := e.tx.transmit(views...)
- e.mu.Unlock()
-
- if !ok {
- return tcpip.ErrWouldBlock
- }
-
- return nil
-}
-
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
// to the network stack.
func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index b3e8c4b92..8d9a91020 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -53,16 +53,35 @@ type endpoint struct {
nested.Endpoint
writer io.Writer
maxPCAPLen uint32
+ logPrefix string
}
var _ stack.GSOEndpoint = (*endpoint)(nil)
var _ stack.LinkEndpoint = (*endpoint)(nil)
var _ stack.NetworkDispatcher = (*endpoint)(nil)
+type direction int
+
+const (
+ directionSend = iota
+ directionRecv
+)
+
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- sniffer := &endpoint{}
+ return NewWithPrefix(lower, "")
+}
+
+// NewWithPrefix creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets prefixed with logPrefix as they traverse
+// the endpoint.
+//
+// logPrefix is prepended to the log line without any separators.
+// E.g. logPrefix = "NIC:en0/" will produce log lines like
+// "NIC:en0/send udp [...]".
+func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoint {
+ sniffer := &endpoint{logPrefix: logPrefix}
sniffer.Endpoint.Init(lower, sniffer)
return sniffer
}
@@ -120,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.dumpPacket("recv", nil, protocol, pkt)
+ e.dumpPacket(directionRecv, nil, protocol, pkt)
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
@@ -129,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc
e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
}
-func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
- logPacket(prefix, protocol, pkt, gso)
+ logPacket(e.logPrefix, dir, protocol, pkt, gso)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
totalLength := pkt.Size()
@@ -169,7 +188,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.dumpPacket("send", gso, protocol, pkt)
+ e.dumpPacket(directionSend, gso, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -178,20 +197,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// forwards the request to the lower endpoint.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.dumpPacket("send", gso, protocol, pkt)
+ e.dumpPacket(directionSend, gso, protocol, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, protocol)
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- return e.Endpoint.WriteRawPacket(vv)
-}
-
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
+func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -201,6 +212,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
var fragmentOffset uint16
var moreFragments bool
+ var directionPrefix string
+ switch dir {
+ case directionSend:
+ directionPrefix = "send"
+ case directionRecv:
+ directionPrefix = "recv"
+ default:
+ panic(fmt.Sprintf("unrecognized direction: %d", dir))
+ }
+
// Clone the packet buffer to not modify the original.
//
// We don't clone the original packet buffer so that the new packet buffer
@@ -248,15 +269,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
arp := header.ARP(pkt.NetworkHeader().View())
log.Infof(
- "%s arp %s (%s) -> %s (%s) valid:%t",
+ "%s%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
+ directionPrefix,
tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
arp.IsValid(),
)
return
default:
- log.Infof("%s unknown network protocol", prefix)
+ log.Infof("%s%s unknown network protocol", prefix, directionPrefix)
return
}
@@ -300,7 +322,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
icmpType = "info reply"
}
}
- log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.ICMPv6ProtocolNumber:
@@ -335,7 +357,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6RedirectMsg:
icmpType = "redirect message"
}
- log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.UDPProtocolNumber:
@@ -391,7 +413,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
}
default:
- log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto)
+ log.Infof("%s%s %s -> %s unknown transport protocol: %d", prefix, directionPrefix, src, dst, transProto)
return
}
@@ -399,5 +421,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
details += fmt.Sprintf(" gso: %+v", gso)
}
- log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+ log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details)
}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 4c14f55d3..9a76bdba7 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -76,29 +76,13 @@ func (d *Device) Release(ctx context.Context) {
}
}
-// NICID returns the NIC ID of the device.
-//
-// Must only be called after the device has been attached to an endpoint.
-func (d *Device) NICID() tcpip.NICID {
- d.mu.RLock()
- defer d.mu.RUnlock()
-
- if d.endpoint == nil {
- panic("called NICID on a device that has not been attached")
- }
-
- return d.endpoint.nicID
-}
-
// SetIff services TUNSETIFF ioctl(2) request.
-//
-// Returns true if a new NIC was created; false if an existing one was attached.
-func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) {
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
d.mu.Lock()
defer d.mu.Unlock()
if d.endpoint != nil {
- return false, syserror.EINVAL
+ return syserror.EINVAL
}
// Input validations.
@@ -106,7 +90,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error)
isTap := flags&linux.IFF_TAP != 0
supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
- return false, syserror.EINVAL
+ return syserror.EINVAL
}
prefix := "tun"
@@ -119,18 +103,18 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error)
linkCaps |= stack.CapabilityResolutionRequired
}
- endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps)
+ endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return false, syserror.EINVAL
+ return syserror.EINVAL
}
d.endpoint = endpoint
d.notifyHandle = d.endpoint.AddNotify(d)
d.flags = flags
- return created, nil
+ return nil
}
-func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) {
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
for {
// 1. Try to attach to an existing NIC.
if name != "" {
@@ -138,13 +122,13 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
endpoint, ok := linkEP.(*tunEndpoint)
if !ok {
// Not a NIC created by tun device.
- return nil, false, syserror.EOPNOTSUPP
+ return nil, syserror.EOPNOTSUPP
}
if !endpoint.TryIncRef() {
// Race detected: NIC got deleted in between.
continue
}
- return endpoint, false, nil
+ return endpoint, nil
}
}
@@ -167,12 +151,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
})
switch err {
case nil:
- return endpoint, true, nil
+ return endpoint, nil
case tcpip.ErrDuplicateNICID:
// Race detected: A NIC has been created in between.
continue
default:
- return nil, false, syserror.EINVAL
+ return nil, syserror.EINVAL
}
}
}
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index ee84c3d96..9b4602c1b 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -11,7 +11,6 @@ go_library(
deps = [
"//pkg/gate",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
@@ -25,7 +24,6 @@ go_test(
library = ":waitable",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index b152a0f26..cf0077f43 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -24,7 +24,6 @@ package waitable
import (
"gvisor.dev/gvisor/pkg/gate"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -132,17 +131,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n, err
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if !e.writeGate.Enter() {
- return nil
- }
-
- err := e.lower.WriteRawPacket(vv)
- e.writeGate.Leave()
- return err
-}
-
// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint,
// and waits for inflight ones to finish before returning.
func (e *Endpoint) WaitWrite() {
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 94827fc56..cf7fb5126 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -18,7 +18,6 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -81,11 +80,6 @@ func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.
return pkts.Len(), nil
}
-func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- e.writeCount++
- return nil
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unimplemented")
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index c118a2929..b38aff0b8 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -14,6 +14,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 33a4a0720..3d5c0d270 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -31,17 +31,15 @@ import (
const (
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
-
- // ProtocolAddress is the address expected by the ARP endpoint.
- ProtocolAddress = tcpip.Address("arp")
)
-var _ stack.AddressableEndpoint = (*endpoint)(nil)
+// ARP endpoints need to implement stack.NetworkEndpoint because the stack
+// considers the layer above the link-layer a network layer; the only
+// facility provided by the stack to deliver packets to a layer above
+// the link-layer is via stack.NetworkEndpoint.HandlePacket.
var _ stack.NetworkEndpoint = (*endpoint)(nil)
type endpoint struct {
- stack.AddressableEndpointState
-
protocol *protocol
// enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
@@ -87,7 +85,7 @@ func (e *endpoint) Disable() {
}
// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
-func (e *endpoint) DefaultTTL() uint8 {
+func (*endpoint) DefaultTTL() uint8 {
return 0
}
@@ -100,25 +98,23 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {
- e.AddressableEndpointState.Cleanup()
-}
+func (*endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
+func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
-func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return ProtocolNumber
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -216,9 +212,8 @@ func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
func (p *protocol) DefaultPrefixLen() int { return 0 }
-func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- h := header.ARP(v)
- return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
+func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
+ return "", ""
}
func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
@@ -228,7 +223,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
linkAddrCache: linkAddrCache,
nud: nud,
}
- e.AddressableEndpointState.Init(e)
return e
}
@@ -311,10 +305,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
}
// NewProtocol returns an ARP network protocol.
-//
-// Note, to make sure that the ARP endpoint receives ARP packets, the "arp"
-// address must be added to every NIC that should respond to ARP requests. See
-// ProtocolAddress for more details.
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return &protocol{stack: s}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 087ee9c66..f462524c9 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -200,9 +200,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
}
- if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("AddAddress for arp failed: %v", err)
- }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
@@ -439,6 +436,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
NetProto: protocol,
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 47fb63290..d8e4a3b54 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -47,6 +47,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/faketime",
"//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/stack",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 936601287..c75ca7d71 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -71,16 +71,25 @@ type FragmentID struct {
// Fragmentation is the main structure that other modules
// of the stack should use to implement IP Fragmentation.
type Fragmentation struct {
- mu sync.Mutex
- highLimit int
- lowLimit int
- reassemblers map[FragmentID]*reassembler
- rList reassemblerList
- size int
- timeout time.Duration
- blockSize uint16
- clock tcpip.Clock
- releaseJob *tcpip.Job
+ mu sync.Mutex
+ highLimit int
+ lowLimit int
+ reassemblers map[FragmentID]*reassembler
+ rList reassemblerList
+ size int
+ timeout time.Duration
+ blockSize uint16
+ clock tcpip.Clock
+ releaseJob *tcpip.Job
+ timeoutHandler TimeoutHandler
+}
+
+// TimeoutHandler is consulted if a packet reassembly has timed out.
+type TimeoutHandler interface {
+ // OnReassemblyTimeout will be called with the first fragment (or nil, if the
+ // first fragment has not been received) of a packet whose reassembly has
+ // timed out.
+ OnReassemblyTimeout(pkt *stack.PacketBuffer)
}
// NewFragmentation creates a new Fragmentation.
@@ -97,7 +106,7 @@ type Fragmentation struct {
// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation {
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
@@ -111,12 +120,13 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
}
f := &Fragmentation{
- reassemblers: make(map[FragmentID]*reassembler),
- highLimit: highMemoryLimit,
- lowLimit: lowMemoryLimit,
- timeout: reassemblingTimeout,
- blockSize: blockSize,
- clock: clock,
+ reassemblers: make(map[FragmentID]*reassembler),
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
+ timeout: reassemblingTimeout,
+ blockSize: blockSize,
+ clock: clock,
+ timeoutHandler: timeoutHandler,
}
f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
@@ -136,16 +146,8 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
// proto is the protocol number marked in the fragment being processed. It has
// to be given here outside of the FragmentID struct because IPv6 should not use
// the protocol to identify a fragment.
-//
-// releaseCB is a callback that will run when the fragment reassembly of a
-// packet is complete or cancelled. releaseCB take a a boolean argument which is
-// true iff the reassembly is cancelled due to timeout. releaseCB should be
-// passed only with the first fragment of a packet. If more than one releaseCB
-// are passed for the same packet, only the first releaseCB will be saved for
-// the packet and the succeeding ones will be dropped by running them
-// immediately with a false argument.
func (f *Fragmentation) Process(
- id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) (
+ id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (
buffer.VectorisedView, uint8, bool, error) {
if first > last {
return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
@@ -160,10 +162,9 @@ func (f *Fragmentation) Process(
return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
}
- if l := vv.Size(); l < int(fragmentSize) {
- return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
+ if l := pkt.Data.Size(); l != int(fragmentSize) {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
}
- vv.CapLength(int(fragmentSize))
f.mu.Lock()
r, ok := f.reassemblers[id]
@@ -179,15 +180,9 @@ func (f *Fragmentation) Process(
f.releaseReassemblersLocked()
}
}
- if releaseCB != nil {
- if !r.setCallback(releaseCB) {
- // We got a duplicate callback. Release it immediately.
- releaseCB(false /* timedOut */)
- }
- }
f.mu.Unlock()
- res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv)
+ res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt)
if err != nil {
// We probably got an invalid sequence of fragments. Just
// discard the reassembler and move on.
@@ -231,7 +226,9 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) {
f.size = 0
}
- r.release(timedOut) // releaseCB may run.
+ if h := f.timeoutHandler; timedOut && h != nil {
+ h.OnReassemblyTimeout(r.pkt)
+ }
}
// releaseReassemblersLocked releases already-expired reassemblers, then
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 5dcd10730..3a79688a8 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// reassembleTimeout is dummy timeout used for testing, where the clock never
@@ -40,13 +41,19 @@ func vv(size int, pieces ...string) buffer.VectorisedView {
return buffer.NewVectorisedView(size, views)
}
+func pkt(size int, pieces ...string) *stack.PacketBuffer {
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv(size, pieces...),
+ })
+}
+
type processInput struct {
id FragmentID
first uint16
last uint16
more bool
proto uint8
- vv buffer.VectorisedView
+ pkt *stack.PacketBuffer
}
type processOutput struct {
@@ -63,8 +70,8 @@ var processTestCases = []struct {
{
comment: "One ID",
in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -74,8 +81,8 @@ var processTestCases = []struct {
{
comment: "Next Header protocol mismatch",
in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -85,10 +92,10 @@ var processTestCases = []struct {
{
comment: "Two IDs",
in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")},
- {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")},
+ {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -102,17 +109,17 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil)
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
- vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil)
+ vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
if err != nil {
- t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s",
- in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err)
+ t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
+ in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
}
if !reflect.DeepEqual(vv, c.out[i].vv) {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)",
- in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView())
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)",
+ in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView())
}
if done != c.out[i].done {
t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
@@ -236,11 +243,11 @@ func TestReassemblingTimeout(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
clock := faketime.NewManualClock()
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock)
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil)
for _, event := range test.events {
clock.Advance(event.clockAdvance)
if frag := event.fragment; frag != nil {
- _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil)
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data))
if err != nil {
t.Fatalf("%s: f.Process failed: %s", event.name, err)
}
@@ -257,17 +264,17 @@ func TestReassemblingTimeout(t *testing.T) {
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
- f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil)
+ f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0"))
// Send first fragment with id = 1.
- f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil)
+ f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1"))
// Send first fragment with id = 2.
- f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil)
+ f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2"))
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
- f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil)
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3"))
if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
@@ -281,11 +288,11 @@ func TestMemoryLimits(t *testing.T) {
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
// Send the same packet again.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
got := f.size
want := 1
@@ -327,6 +334,7 @@ func TestErrors(t *testing.T) {
last: 3,
more: true,
data: "012",
+ err: ErrInvalidArgs,
},
{
name: "exact block size with more and too little data",
@@ -376,8 +384,8 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
- _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil)
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil)
+ _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data))
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
}
@@ -498,57 +506,92 @@ func TestPacketFragmenter(t *testing.T) {
}
}
-func TestReleaseCallback(t *testing.T) {
+type testTimeoutHandler struct {
+ pkt *stack.PacketBuffer
+}
+
+func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
+ h.pkt = pkt
+}
+
+func TestTimeoutHandler(t *testing.T) {
const (
proto = 99
)
- var result int
- var callbackReasonIsTimeout bool
- cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut }
- cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut }
+ pk1 := pkt(1, "1")
+ pk2 := pkt(1, "2")
+
+ type processParam struct {
+ first uint16
+ last uint16
+ more bool
+ pkt *stack.PacketBuffer
+ }
tests := []struct {
- name string
- callbacks []func(bool)
- timeout bool
- wantResult int
- wantCallbackReasonIsTimeout bool
+ name string
+ params []processParam
+ wantError bool
+ wantPkt *stack.PacketBuffer
}{
{
- name: "callback runs on release",
- callbacks: []func(bool){cb1},
- timeout: false,
- wantResult: 1,
- wantCallbackReasonIsTimeout: false,
- },
- {
- name: "first callback is nil",
- callbacks: []func(bool){nil, cb2},
- timeout: false,
- wantResult: 2,
- wantCallbackReasonIsTimeout: false,
+ name: "onTimeout runs",
+ params: []processParam{
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: false,
+ wantPkt: pk1,
},
{
- name: "two callbacks - first one is set",
- callbacks: []func(bool){cb1, cb2},
- timeout: false,
- wantResult: 1,
- wantCallbackReasonIsTimeout: false,
+ name: "no first fragment",
+ params: []processParam{
+ {
+ first: 1,
+ last: 1,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: false,
+ wantPkt: nil,
},
{
- name: "callback runs on timeout",
- callbacks: []func(bool){cb1},
- timeout: true,
- wantResult: 1,
- wantCallbackReasonIsTimeout: true,
+ name: "second pkt is ignored",
+ params: []processParam{
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk2,
+ },
+ },
+ wantError: false,
+ wantPkt: pk1,
},
{
- name: "no callbacks",
- callbacks: []func(bool){nil},
- timeout: false,
- wantResult: 0,
- wantCallbackReasonIsTimeout: false,
+ name: "invalid args - first is greater than last",
+ params: []processParam{
+ {
+ first: 1,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: true,
+ wantPkt: nil,
},
}
@@ -556,29 +599,31 @@ func TestReleaseCallback(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- result = 0
- callbackReasonIsTimeout = false
+ handler := &testTimeoutHandler{pkt: nil}
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler)
- for i, cb := range test.callbacks {
- _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb)
- if err != nil {
+ for _, p := range test.params {
+ if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError {
t.Errorf("f.Process error = %s", err)
}
}
-
- r, ok := f.reassemblers[id]
- if !ok {
- t.Fatalf("Reassemberr not found")
- }
- f.release(r, test.timeout)
-
- if result != test.wantResult {
- t.Errorf("got result = %d, want = %d", result, test.wantResult)
+ if !test.wantError {
+ r, ok := f.reassemblers[id]
+ if !ok {
+ t.Fatal("Reassembler not found")
+ }
+ f.release(r, true)
}
- if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout {
- t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout)
+ switch {
+ case handler.pkt != nil && test.wantPkt == nil:
+ t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView())
+ case handler.pkt == nil && test.wantPkt != nil:
+ t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView())
+ case handler.pkt != nil && test.wantPkt != nil:
+ if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" {
+ t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff)
+ }
}
})
}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index c0cc0bde0..19f4920b3 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
type hole struct {
@@ -41,7 +42,7 @@ type reassembler struct {
heap fragHeap
done bool
creationTime int64
- callback func(bool)
+ pkt *stack.PacketBuffer
}
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
@@ -79,7 +80,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
return used
}
-func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) {
+func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
@@ -89,18 +90,20 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buf
// was waiting on the mutex. We don't have to do anything in this case.
return buffer.VectorisedView{}, 0, false, consumed, nil
}
- // For IPv6, it is possible to have different Protocol values between
- // fragments of a packet (because, unlike IPv4, the Protocol is not used to
- // identify a fragment). In this case, only the Protocol of the first
- // fragment must be used as per RFC 8200 Section 4.5.
- //
- // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded
- // here (instead of just the protocol) because most IP options should be
- // derived from the first fragment.
- if first == 0 {
- r.proto = proto
- }
if r.updateHoles(first, last, more) {
+ // For IPv6, it is possible to have different Protocol values between
+ // fragments of a packet (because, unlike IPv4, the Protocol is not used to
+ // identify a fragment). In this case, only the Protocol of the first
+ // fragment must be used as per RFC 8200 Section 4.5.
+ //
+ // TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP
+ // options received in the first fragment should be used - and they should
+ // override options from following fragments.
+ if first == 0 {
+ r.pkt = pkt
+ r.proto = proto
+ }
+ vv := pkt.Data
// We store the incoming packet only if it filled some holes.
heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
consumed = vv.Size()
@@ -124,24 +127,3 @@ func (r *reassembler) checkDoneOrMark() bool {
r.mu.Unlock()
return prev
}
-
-func (r *reassembler) setCallback(c func(bool)) bool {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.callback != nil {
- return false
- }
- r.callback = c
- return true
-}
-
-func (r *reassembler) release(timedOut bool) {
- r.mu.Lock()
- callback := r.callback
- r.callback = nil
- r.mu.Unlock()
-
- if callback != nil {
- callback(timedOut)
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index fa2a70dc8..a0a04a027 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -105,26 +105,3 @@ func TestUpdateHoles(t *testing.T) {
}
}
}
-
-func TestSetCallback(t *testing.T) {
- result := 0
- reasonTimeout := false
-
- cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut }
- cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut }
-
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- if !r.setCallback(cb1) {
- t.Errorf("setCallback failed")
- }
- if r.setCallback(cb2) {
- t.Errorf("setCallback should fail if one is already set")
- }
- r.release(true)
- if result != 1 {
- t.Errorf("got result = %d, want = 1", result)
- }
- if !reasonTimeout {
- t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout)
- }
-}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index c7d26e14f..787399e08 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -34,16 +35,16 @@ import (
)
const (
- localIPv4Addr = "\x0a\x00\x00\x01"
- remoteIPv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01")
+ remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02")
+ ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00")
+ ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00")
+ ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03")
+ localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")
+ ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00")
+ ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
nicID = 1
)
@@ -192,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu
panic("not implemented")
}
-func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*testObject) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
@@ -299,6 +296,10 @@ func (t *testInterface) Enabled() bool {
return !t.mu.disabled
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) setEnabled(v bool) {
t.mu.Lock()
defer t.mu.Unlock()
@@ -558,59 +559,135 @@ func TestIPv4Send(t *testing.T) {
}
}
-func TestIPv4Receive(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- v4: true,
+func TestReceive(t *testing.T) {
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ v4 bool
+ epAddr tcpip.AddressWithPrefix
+ handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface)
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ v4: true,
+ epAddr: localIPv4Addr.WithPrefix(),
+ handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
+ const totalLen = header.IPv4MinimumSize + 30 /* payload length */
+
+ view := buffer.NewView(totalLen)
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: totalLen,
+ TTL: ipv4.DefaultTTL,
+ Protocol: 10,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if ok := parse.IPv4(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(pkt)
+ },
},
- }
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
- defer ep.Close()
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ v4: false,
+ epAddr: localIPv6Addr.WithPrefix(),
+ handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
+ const payloadLen = 30
+ view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: payloadLen,
+ NextHeader: 10,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
+ })
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
+ // Make payload be non-zero.
+ for i := header.IPv6MinimumSize; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
- totalLen := header.IPv4MinimumSize + 30
- view := buffer.NewView(totalLen)
- ip := header.IPv4(view)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- SrcAddr: remoteIPv4Addr,
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
+ // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen]
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- view[i] = uint8(i)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if _, _, _, _, ok := parse.IPv6(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(pkt)
+ },
+ },
}
- // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv4Addr
- nic.testObject.dstAddr = localIPv4Addr
- nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
+ })
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: test.v4,
+ },
+ }
+ ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject)
+ defer ep.Close()
- r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
- r.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
+ }
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
+ } else {
+ ep.DecRef()
+ }
+
+ stat := s.Stats().IP.PacketsReceived
+ if got := stat.Value(); got != 0 {
+ t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ test.handlePacket(t, ep, &nic)
+ if nic.testObject.dataCalls != 1 {
+ t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ }
+ if got := stat.Value(); got != 1 {
+ t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ })
}
}
@@ -634,10 +711,6 @@ func TestIPv4ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
}
- r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb")
- if err != nil {
- t.Fatal(err)
- }
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
s := buildDummyStack(t)
@@ -705,8 +778,18 @@ func TestIPv4ReceiveControl(t *testing.T) {
nic.testObject.typ = c.expectedTyp
nic.testObject.extra = c.expectedExtra
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := localIPv4Addr.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
+
pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
@@ -716,7 +799,9 @@ func TestIPv4ReceiveControl(t *testing.T) {
}
func TestIPv4FragmentationReceive(t *testing.T) {
- s := buildDummyStack(t)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
testObject: testObject{
@@ -774,11 +859,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
nic.testObject.dstAddr = localIPv4Addr
nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
- r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
-
// Send first segment.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: frag1.ToVectorisedView(),
@@ -786,7 +866,18 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- r.PopulatePacketInfo(pkt)
+
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := localIPv4Addr.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
+
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
@@ -799,7 +890,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
@@ -852,61 +942,6 @@ func TestIPv6Send(t *testing.T) {
}
}
-func TestIPv6Receive(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- },
- }
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- totalLen := header.IPv6MinimumSize + 30
- view := buffer.NewView(totalLen)
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: remoteIPv6Addr,
- DstAddr: localIPv6Addr,
- })
-
- // Make payload be non-zero.
- for i := header.IPv6MinimumSize; i < totalLen; i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv6Addr
- nic.testObject.dstAddr = localIPv6Addr
- nic.testObject.contents = view[header.IPv6MinimumSize:totalLen]
-
- r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
- r.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
- }
-}
-
func TestIPv6ReceiveControl(t *testing.T) {
newUint16 := func(v uint16) *uint16 { return &v }
@@ -933,13 +968,6 @@ func TestIPv6ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
}
- r, err := buildIPv6Route(
- localIPv6Addr,
- "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
- )
- if err != nil {
- t.Fatal(err)
- }
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
s := buildDummyStack(t)
@@ -1018,8 +1046,17 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := localIPv6Addr.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
@@ -1202,7 +1239,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ipHdrLen := header.IPv4MinimumSize + ipv4Options.AllocationSize()
+ ipHdrLen := header.IPv4MinimumSize + ipv4Options.SizeWithPadding()
totalLen := ipHdrLen + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1247,7 +1284,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.AllocationSize()))
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.SizeWithPadding()))
ip.Encode(&header.IPv4Fields{
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 9b5e37fee..488945226 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -90,7 +90,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
iph := header.IPv4(pkt.NetworkHeader().View())
var newOptions header.IPv4Options
- if len(iph) > header.IPv4MinimumSize {
+ if opts := iph.Options(); len(opts) != 0 {
// RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip
// type ICMP packets):
// If a Record Route and/or Time Stamp option is received in an
@@ -106,7 +106,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
} else {
op = &optionUsageReceive{}
}
- aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op)
+ aux, tmp, err := e.processIPOptions(pkt, opts, op)
if err != nil {
switch {
case
@@ -290,6 +290,13 @@ type icmpReasonProtoUnreachable struct{}
func (*icmpReasonProtoUnreachable) isICMPReason() {}
+// icmpReasonTTLExceeded is an error where a packet's time to live exceeded in
+// transit to its final destination, as per RFC 792 page 6, Time Exceeded
+// Message.
+type icmpReasonTTLExceeded struct{}
+
+func (*icmpReasonTTLExceeded) isICMPReason() {}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -342,11 +349,31 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
return nil
}
+ // If we hit a TTL Exceeded error, then we know we are operating as a router.
+ // As per RFC 792 page 6, Time Exceeded Message,
+ //
+ // If the gateway processing a datagram finds the time to live field
+ // is zero it must discard the datagram. The gateway may also notify
+ // the source host via the time exceeded message.
+ //
+ // ...
+ //
+ // Code 0 may be received from a gateway. ...
+ //
+ // Note, Code 0 is the TTL exceeded error.
+ //
+ // If we are operating as a router/gateway, don't use the packet's destination
+ // address as the response's source address as we should not not own the
+ // destination address of a packet we are forwarding.
+ localAddr := origIPHdrDst
+ if _, ok := reason.(*icmpReasonTTLExceeded); ok {
+ localAddr = ""
+ }
// Even if we were able to receive a packet from some remote, we may not have
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
@@ -454,6 +481,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonTTLExceeded:
+ icmpHdr.SetType(header.ICMPv4TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv4TTLExceeded)
+ counter = sent.TimeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
@@ -483,3 +514,18 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
counter.Increment()
return nil
}
+
+// OnReassemblyTimeout implements fragmentation.TimeoutHandler.
+func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
+ // OnReassemblyTimeout sends a Time Exceeded Message, as per RFC 792:
+ //
+ // If a host reassembling a fragmented datagram cannot complete the
+ // reassembly due to missing fragments within its time limit it discards the
+ // datagram, and it may send a time exceeded message.
+ //
+ // If fragment zero is not available then no time exceeded need be sent at
+ // all.
+ if pkt != nil {
+ p.returnError(&icmpReasonReassemblyTimeout{}, pkt)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index a376cb8ec..1efe6297a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -206,12 +206,12 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
if opts, ok = params.Options.(header.IPv4Options); !ok {
panic(fmt.Sprintf("want IPv4Options, got %T", params.Options))
}
- hdrLen += opts.AllocationSize()
+ hdrLen += opts.SizeWithPadding()
if hdrLen > header.IPv4MaximumHeaderSize {
// Since we have no way to report an error we must either panic or create
// a packet which is different to what was requested. Choose panic as this
// would be a programming error that should be caught in testing.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.AllocationSize(), header.IPv4MaximumOptionsSize))
+ panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.SizeWithPadding(), header.IPv4MaximumOptionsSize))
}
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
@@ -260,16 +260,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
- return e.writePacket(r, gso, pkt)
-}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesOutputDropped.Increment()
+ e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
@@ -286,24 +283,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
// Since we rewrote the packet but it is being routed back to us, we can
// safely assume the checksum is valid.
pkt.RXTransportChecksumValidated = true
- ep.HandlePacket(pkt)
+ ep.(*endpoint).handlePacket(pkt)
}
return nil
}
}
+ return e.writePacket(r, gso, pkt, false /* headerIncluded */)
+}
+
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- loopedR := r.MakeLoopedRoute()
- loopedR.PopulatePacketInfo(pkt)
- loopedR.Release()
- e.HandlePacket(pkt)
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = !headerIncluded
+ e.handlePacket(pkt)
}
}
if r.Loop&stack.PacketOut == 0 {
@@ -374,8 +374,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.protocol.stack.IPTables()
- dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -400,9 +399,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.(*endpoint).handlePacket(pkt)
}
n++
continue
@@ -479,16 +479,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrMalformedHeader
}
- return e.writePacket(r, nil /* gso */, pkt)
+ return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */)
+}
+
+// forwardPacket attempts to forward a packet to its final destination.
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ ttl := h.TTL()
+ if ttl == 0 {
+ // As per RFC 792 page 6, Time Exceeded Message,
+ //
+ // If the gateway processing a datagram finds the time to live field
+ // is zero it must discard the datagram. The gateway may also notify
+ // the source host via the time exceeded message.
+ return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
+ }
+
+ dstAddr := h.DestinationAddress()
+
+ // Check if the destination is owned by the stack.
+ networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
+ if err == nil {
+ networkEndpoint.(*endpoint).handlePacket(pkt)
+ return nil
+ }
+ if err != tcpip.ErrBadAddress {
+ return err
+ }
+
+ r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader()))
+
+ // As per RFC 791 page 30, Time to Live,
+ //
+ // This field must be decreased at each point that the internet header
+ // is processed to reflect the time spent processing the datagram.
+ // Even if no local information is available on the time actually
+ // spent, the field must be decremented by 1.
+ newHdr.SetTTL(ttl - 1)
+
+ return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(newHdr).ToVectorisedView(),
+ }))
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
+ stats.IP.PacketsReceived.Increment()
+
if !e.isEnabled() {
+ stats.IP.DisabledPacketsReceived.Increment()
return
}
+ // Loopback traffic skips the prerouting chain.
+ if !e.nic.IsLoopback() {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IP.IPTablesPreroutingDropped.Increment()
+ return
+ }
+ }
+
+ e.handlePacket(pkt)
+}
+
+// handlePacket is like HandlePacket except it does not perform the prerouting
+// iptables hook.
+func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.protocol.stack.Stats()
@@ -497,6 +566,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
stats.IP.MalformedPacketsReceived.Increment()
return
}
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
+
+ addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ _ = e.forwardPacket(pkt)
+ return
+ }
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ addressEndpoint.DecRef()
// There has been some confusion regarding verifying checksums. We need
// just look for negative 0 (0xffff) as the checksum, as it's not possible to
@@ -528,15 +612,16 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) {
+ if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
+ pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
+
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
stats.IP.IPTablesInputDropped.Increment()
return
@@ -565,29 +650,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- // Set up a callback in case we need to send a Time Exceeded Message, as per
- // RFC 792:
- //
- // If a host reassembling a fragmented datagram cannot complete the
- // reassembly due to missing fragments within its time limit it discards
- // the datagram, and it may send a time exceeded message.
- //
- // If fragment zero is not available then no time exceeded need be sent at
- // all.
- var releaseCB func(bool)
- if start == 0 {
- pkt := pkt.Clone()
- releaseCB = func(timedOut bool) {
- if timedOut {
- _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
- }
- }
- }
-
- var ready bool
- var err error
proto := h.Protocol()
- pkt.Data, _, ready, err = e.protocol.fragmentation.Process(
+ data, _, ready, err := e.protocol.fragmentation.Process(
// As per RFC 791 section 2.3, the identification value is unique
// for a source-destination pair and protocol.
fragmentation.FragmentID{
@@ -600,8 +664,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
start+uint16(pkt.Data.Size())-1,
h.More(),
proto,
- pkt.Data,
- releaseCB,
+ pkt,
)
if err != nil {
stats.IP.MalformedPacketsReceived.Increment()
@@ -611,6 +674,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !ready {
return
}
+ pkt.Data = data
// The reassembler doesn't take care of fixing up the header, so we need
// to do it here.
@@ -628,11 +692,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.handleICMP(pkt)
return
}
- if len(h.Options()) != 0 {
+ if opts := h.Options(); len(opts) != 0 {
// TODO(gvisor.dev/issue/4586):
// When we add forwarding support we should use the verified options
// rather than just throwing them away.
- aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{})
+ aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{})
if err != nil {
switch {
case
@@ -778,6 +842,7 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
+var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
stack *stack.Stack
@@ -942,13 +1007,14 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
}
hashIV := r[buckets]
- return &protocol{
- stack: s,
- ids: ids,
- hashIV: hashIV,
- defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
+ p := &protocol{
+ stack: s,
+ ids: ids,
+ hashIV: hashIV,
+ defaultTTL: DefaultTTL,
}
+ p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
+ return p
}
func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index c6e565455..4e4e1f3b4 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -103,6 +103,262 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
+// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested
+// fields when options are supplied.
+func TestIPv4EncodeOptions(t *testing.T) {
+ tests := []struct {
+ name string
+ options header.IPv4Options
+ encodedOptions header.IPv4Options // reply should look like this
+ wantIHL int
+ }{
+ {
+ name: "valid no options",
+ wantIHL: header.IPv4MinimumSize,
+ },
+ {
+ name: "one byte options",
+ options: header.IPv4Options{1},
+ encodedOptions: header.IPv4Options{1, 0, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "two byte options",
+ options: header.IPv4Options{1, 1},
+ encodedOptions: header.IPv4Options{1, 1, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "three byte options",
+ options: header.IPv4Options{1, 1, 1},
+ encodedOptions: header.IPv4Options{1, 1, 1, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "four byte options",
+ options: header.IPv4Options{1, 1, 1, 1},
+ encodedOptions: header.IPv4Options{1, 1, 1, 1},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "five byte options",
+ options: header.IPv4Options{1, 1, 1, 1, 1},
+ encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 8,
+ },
+ {
+ name: "thirty nine byte options",
+ options: header.IPv4Options{
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 21, 22, 23, 24,
+ 25, 26, 27, 28, 29, 30, 31, 32,
+ 33, 34, 35, 36, 37, 38, 39,
+ },
+ encodedOptions: header.IPv4Options{
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 21, 22, 23, 24,
+ 25, 26, 27, 28, 29, 30, 31, 32,
+ 33, 34, 35, 36, 37, 38, 39, 0,
+ },
+ wantIHL: header.IPv4MinimumSize + 40,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ paddedOptionLength := test.options.SizeWithPadding()
+ ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength)
+ hdr := buffer.NewPrependable(int(totalLen))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ // To check the padding works, poison the last byte of the options space.
+ if paddedOptionLength != len(test.options) {
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ ip.Options()[paddedOptionLength-1] = 0xff
+ ip.SetHeaderLength(0)
+ }
+ ip.Encode(&header.IPv4Fields{
+ Options: test.options,
+ })
+ options := ip.Options()
+ wantOptions := test.encodedOptions
+ if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
+ t.Errorf("got IHL of %d, want %d", got, want)
+ }
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(wantOptions) == 0 && len(options) == 0 {
+ return
+ }
+
+ if diff := cmp.Diff(wantOptions, options); diff != "" {
+ t.Errorf("options mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ randomSequence = 123
+ randomIdent = 42
+ )
+
+ ipv4Addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ }
+ ipv4Addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
+ PrefixLen: 8,
+ }
+ remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4())
+ remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
+
+ tests := []struct {
+ name string
+ TTL uint8
+ expectErrorICMP bool
+ }{
+ {
+ name: "TTL of zero",
+ TTL: 0,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of one",
+ TTL: 1,
+ expectErrorICMP: false,
+ },
+ {
+ name: "TTL of two",
+ TTL: 2,
+ expectErrorICMP: false,
+ },
+ {
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectErrorICMP: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e1 := channel.New(1, ipv4.MaxTotalSize, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1}
+ if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err)
+ }
+
+ e2 := channel.New(1, ipv4.MaxTotalSize, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2}
+ if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: ipv4Addr1.Subnet(),
+ NIC: nicID1,
+ },
+ {
+ Destination: ipv4Addr2.Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err)
+ }
+
+ totalLen := uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: totalLen,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr1,
+ DstAddr: remoteIPv4Addr2,
+ })
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+
+ if test.expectErrorICMP {
+ reply, ok := e1.Read()
+ if !ok {
+ t.Fatal("expected ICMP TTL Exceeded packet through incoming NIC")
+ }
+
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(ipv4Addr1.Address),
+ checker.DstAddr(remoteIPv4Addr1),
+ checker.TTL(ipv4.DefaultTTL),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4TimeExceeded),
+ checker.ICMPv4Code(header.ICMPv4TTLExceeded),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+
+ if n := e2.Drain(); n != 0 {
+ t.Fatalf("got e2.Drain() = %d, want = 0", n)
+ }
+ } else {
+ reply, ok := e2.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo packet through outgoing NIC")
+ }
+
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(remoteIPv4Addr1),
+ checker.DstAddr(remoteIPv4Addr2),
+ checker.TTL(test.TTL-1),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Payload(nil),
+ ),
+ )
+
+ if n := e1.Drain(); n != 0 {
+ t.Fatalf("got e1.Drain() = %d, want = 0", n)
+ }
+ }
+ })
+ }
+}
+
// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
// checks the response.
func TestIPv4Sanity(t *testing.T) {
@@ -197,6 +453,14 @@ func TestIPv4Sanity(t *testing.T) {
replyOptions: header.IPv4Options{1, 1, 0, 0},
},
{
+ name: "Check option padding",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{1, 1, 1},
+ replyOptions: header.IPv4Options{1, 1, 1, 0},
+ },
+ {
name: "bad header length",
headerLength: header.IPv4MinimumSize - 1,
maxTotalLength: ipv4.MaxTotalSize,
@@ -599,9 +863,10 @@ func TestIPv4Sanity(t *testing.T) {
},
})
- ipHeaderLength := header.IPv4MinimumSize + test.options.AllocationSize()
+ paddedOptionLength := test.options.SizeWithPadding()
+ ipHeaderLength := header.IPv4MinimumSize + paddedOptionLength
if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
}
totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
hdr := buffer.NewPrependable(int(totalLen))
@@ -618,6 +883,12 @@ func TestIPv4Sanity(t *testing.T) {
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
}
+ // To check the padding works, poison the options space.
+ if paddedOptionLength != len(test.options) {
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ ip.Options()[paddedOptionLength-1] = 0x01
+ }
+
ip.Encode(&header.IPv4Fields{
TotalLength: totalLen,
Protocol: test.transportProtocol,
@@ -732,7 +1003,7 @@ func TestIPv4Sanity(t *testing.T) {
}
// If the IP options change size then the packet will change size, so
// some IP header fields will need to be adjusted for the checks.
- sizeChange := len(test.replyOptions) - len(test.options)
+ sizeChange := len(test.replyOptions) - paddedOptionLength
checker.IPv4(t, replyIPHeader,
checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
@@ -2441,9 +2712,6 @@ func TestPacketQueing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 8502b848c..beb8f562e 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -750,6 +750,12 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in
+// transit to its final destination, as per RFC 4443 section 3.3.
+type icmpReasonHopLimitExceeded struct{}
+
+func (*icmpReasonHopLimitExceeded) isICMPReason() {}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -794,11 +800,27 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
return nil
}
+ // If we hit a Hop Limit Exceeded error, then we know we are operating as a
+ // router. As per RFC 4443 section 3.3:
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard the
+ // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
+ // the source of the packet. This indicates either a routing loop or
+ // too small an initial Hop Limit value.
+ //
+ // If we are operating as a router, do not use the packet's destination
+ // address as the response's source address as we should not own the
+ // destination address of a packet we are forwarding.
+ localAddr := origIPHdrDst
+ if _, ok := reason.(*icmpReasonHopLimitExceeded); ok {
+ localAddr = ""
+ }
// Even if we were able to receive a packet from some remote, we may not have
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
@@ -811,8 +833,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
return nil
}
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
-
if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
// TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
// Unfortunately at this time ICMP Packets do not have a transport
@@ -830,6 +850,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
}
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+
// As per RFC 4443 section 2.4
//
// (c) Every ICMPv6 error message (type < 128) MUST include
@@ -873,6 +895,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonHopLimitExceeded:
+ icmpHdr.SetType(header.ICMPv6TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
+ counter = sent.TimeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
@@ -896,3 +922,16 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
counter.Increment()
return nil
}
+
+// OnReassemblyTimeout implements fragmentation.TimeoutHandler.
+func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
+ // OnReassemblyTimeout sends a Time Exceeded Message as per RFC 2460 Section
+ // 4.5:
+ //
+ // If the first fragment (i.e., the one with a Fragment Offset of zero) has
+ // been received, an ICMP Time Exceeded -- Fragment Reassembly Time Exceeded
+ // message should be sent to the source of that fragment.
+ if pkt != nil {
+ p.returnError(&icmpReasonReassemblyTimeout{}, pkt)
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 76013daa1..9bc02d851 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -144,6 +144,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
NetProto: protocol,
@@ -174,13 +178,8 @@ func TestICMPCounts(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: test.useNeighborCache,
})
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -206,11 +205,16 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := lladdr0.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
var tllData [header.NDPLinkLayerAddressSize]byte
header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
@@ -279,10 +283,9 @@ func TestICMPCounts(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -290,7 +293,7 @@ func TestICMPCounts(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
@@ -317,13 +320,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: true,
})
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -349,11 +347,16 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := lladdr0.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
var tllData [header.NDPLinkLayerAddressSize]byte
header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
@@ -422,10 +425,9 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -433,7 +435,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
@@ -1775,17 +1777,19 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := lladdr0.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
-
- // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch.
- r.LocalAddress = test.destination
icmp := test.createPacket()
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.IPv6MinimumSize,
Data: buffer.View(icmp).ToVectorisedView(),
@@ -1795,10 +1799,9 @@ func TestCallsToNeighborCache(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.RemoteAddress,
- DstAddr: r.LocalAddress,
+ SrcAddr: test.source,
+ DstAddr: test.destination,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
// Confirm the endpoint calls the correct NUDHandler method.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 0526190cc..7a00f6314 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -441,17 +441,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
- return e.writePacket(r, gso, pkt, params.Protocol)
-}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesOutputDropped.Increment()
+ e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
@@ -467,24 +463,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
// Since we rewrote the packet but it is being routed back to us, we can
// safely assume the checksum is valid.
pkt.RXTransportChecksumValidated = true
- ep.HandlePacket(pkt)
+ ep.(*endpoint).handlePacket(pkt)
}
return nil
}
}
+ return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */)
+}
+
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- loopedR := r.MakeLoopedRoute()
- loopedR.PopulatePacketInfo(pkt)
- loopedR.Release()
- e.HandlePacket(pkt)
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = !headerIncluded
+ e.handlePacket(pkt)
}
}
if r.Loop&stack.PacketOut == 0 {
@@ -558,8 +557,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -584,9 +582,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.(*endpoint).handlePacket(pkt)
}
n++
continue
@@ -640,16 +639,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrMalformedHeader
}
- return e.writePacket(r, nil /* gso */, pkt, proto)
+ return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */)
+}
+
+// forwardPacket attempts to forward a packet to its final destination.
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+ h := header.IPv6(pkt.NetworkHeader().View())
+ hopLimit := h.HopLimit()
+ if hopLimit <= 1 {
+ // As per RFC 4443 section 3.3,
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard the
+ // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
+ // the source of the packet. This indicates either a routing loop or
+ // too small an initial Hop Limit value.
+ return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
+ }
+
+ dstAddr := h.DestinationAddress()
+
+ // Check if the destination is owned by the stack.
+ networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
+ if err == nil {
+ networkEndpoint.(*endpoint).handlePacket(pkt)
+ return nil
+ }
+ if err != tcpip.ErrBadAddress {
+ return err
+ }
+
+ r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader()))
+
+ // As per RFC 8200 section 3,
+ //
+ // Hop Limit 8-bit unsigned integer. Decremented by 1 by
+ // each node that forwards the packet.
+ newHdr.SetHopLimit(hopLimit - 1)
+
+ return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(newHdr).ToVectorisedView(),
+ }))
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
+ stats.IP.PacketsReceived.Increment()
+
if !e.isEnabled() {
+ stats.IP.DisabledPacketsReceived.Increment()
return
}
+ // Loopback traffic skips the prerouting chain.
+ if !e.nic.IsLoopback() {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IP.IPTablesPreroutingDropped.Increment()
+ return
+ }
+ }
+
+ e.handlePacket(pkt)
+}
+
+// handlePacket is like HandlePacket except it does not perform the prerouting
+// iptables hook.
+func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.protocol.stack.Stats()
@@ -669,6 +737,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
+ addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ _ = e.forwardPacket(pkt)
+ return
+ }
+ addressEndpoint.DecRef()
+
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
@@ -681,8 +761,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
stats.IP.IPTablesInputDropped.Increment()
return
@@ -888,18 +967,6 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- // Set up a callback in case we need to send a Time Exceeded Message as
- // per RFC 2460 Section 4.5.
- var releaseCB func(bool)
- if start == 0 {
- pkt := pkt.Clone()
- releaseCB = func(timedOut bool) {
- if timedOut {
- _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
- }
- }
- }
-
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
data, proto, ready, err := e.protocol.fragmentation.Process(
@@ -914,17 +981,17 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
start+uint16(fragmentPayloadLen)-1,
extHdr.More(),
uint8(rawPayload.Identifier),
- rawPayload.Buf,
- releaseCB,
+ pkt,
)
if err != nil {
stats.IP.MalformedPacketsReceived.Increment()
stats.IP.MalformedFragmentsReceived.Increment()
return
}
- pkt.Data = data
if ready {
+ pkt.Data = data
+
// We create a new iterator with the reassembled packet because we could
// have more extension headers in the reassembled payload, as per RFC
// 8200 section 4.5. We also use the NextHeader value from the first
@@ -1335,6 +1402,7 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
+var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
stack *stack.Stack
@@ -1590,10 +1658,9 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
- stack: s,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
- ids: ids,
- hashIV: hashIV,
+ stack: s,
+ ids: ids,
+ hashIV: hashIV,
ndpDisp: opts.NDPDisp,
ndpConfigs: opts.NDPConfigs,
@@ -1601,6 +1668,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
tempIIDSeed: opts.TempIIDSeed,
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
}
+ p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[*endpoint]struct{})
p.SetDefaultTTL(DefaultTTL)
return p
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 1bfcdde25..a671d4bac 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -18,6 +18,7 @@ import (
"encoding/hex"
"fmt"
"math"
+ "net"
"testing"
"github.com/google/go-cmp/cmp"
@@ -2821,3 +2822,160 @@ func TestFragmentationErrors(t *testing.T) {
})
}
}
+
+func TestForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ randomSequence = 123
+ randomIdent = 42
+ )
+
+ ipv6Addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ }
+ ipv6Addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11::1").To16()),
+ PrefixLen: 64,
+ }
+ remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16())
+ remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16())
+
+ tests := []struct {
+ name string
+ TTL uint8
+ expectErrorICMP bool
+ }{
+ {
+ name: "TTL of zero",
+ TTL: 0,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of one",
+ TTL: 1,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of two",
+ TTL: 2,
+ expectErrorICMP: false,
+ },
+ {
+ name: "TTL of three",
+ TTL: 3,
+ expectErrorICMP: false,
+ },
+ {
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectErrorICMP: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1}
+ if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2}
+ if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: ipv6Addr1.Subnet(),
+ NIC: nicID1,
+ },
+ {
+ Destination: ipv6Addr2.Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ if err := s.SetForwarding(ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err)
+ }
+
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize)
+ icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv6EchoRequest)
+ icmp.SetCode(header.ICMPv6UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: test.TTL,
+ SrcAddr: remoteIPv6Addr1,
+ DstAddr: remoteIPv6Addr2,
+ })
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e1.InjectInbound(ProtocolNumber, requestPkt)
+
+ if test.expectErrorICMP {
+ reply, ok := e1.Read()
+ if !ok {
+ t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC")
+ }
+
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(ipv6Addr1.Address),
+ checker.DstAddr(remoteIPv6Addr1),
+ checker.TTL(DefaultTTL),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6TimeExceeded),
+ checker.ICMPv6Code(header.ICMPv6HopLimitExceeded),
+ checker.ICMPv6Payload([]byte(hdr.View())),
+ ),
+ )
+
+ if n := e2.Drain(); n != 0 {
+ t.Fatalf("got e2.Drain() = %d, want = 0", n)
+ }
+ } else {
+ reply, ok := e2.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
+ }
+
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(remoteIPv6Addr1),
+ checker.DstAddr(remoteIPv6Addr2),
+ checker.TTL(test.TTL-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest),
+ checker.ICMPv6Code(header.ICMPv6UnusedCode),
+ checker.ICMPv6Payload(nil),
+ ),
+ )
+
+ if n := e1.Drain(); n != 0 {
+ t.Fatalf("got e1.Drain() = %d, want = 0", n)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 981d1371a..37e8b1083 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
- }
-
{
subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
if err != nil {
@@ -73,6 +69,17 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
}
t.Cleanup(ep.Close)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := llladdr.WithPrefix()
+ if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ addressEP.DecRef()
+ }
+
return s, ep
}
@@ -961,22 +968,17 @@ func TestNDPValidation(t *testing.T) {
for _, stackTyp := range stacks {
t.Run(stackTyp.name, func(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
// Create a stack with the assigned link-local address lladdr0
// and an endpoint to lladdr1.
s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache)
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep, r
+ return s, ep
}
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
nextHdr := uint8(header.ICMPv6ProtocolNumber)
var extensions buffer.View
if atomicFragment {
@@ -994,13 +996,12 @@ func TestNDPValidation(t *testing.T) {
PayloadLength: uint16(len(payload) + len(extensions)),
NextHeader: nextHdr,
HopLimit: hopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
}
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -1114,8 +1115,7 @@ func TestNDPValidation(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, test := range subTests {
t.Run(test.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
+ s, ep := setup(t)
if isRouter {
// Enabling forwarding makes the stack act as a router.
@@ -1131,7 +1131,7 @@ func TestNDPValidation(t *testing.T) {
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
// Rx count of the NDP message should initially be 0.
if got := typStat.Value(); got != 0 {
@@ -1152,7 +1152,7 @@ func TestNDPValidation(t *testing.T) {
t.FailNow()
}
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep)
// Rx count of the NDP packet should have increased.
if got := typStat.Value(); got != 1 {
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index 7cc52985e..5c3363759 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st
return n, nil
}
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if ep.allowPackets == 0 {
- return ep.err
- }
- ep.allowPackets--
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- ep.WrittenPackets = append(ep.WrittenPackets, pkt)
-
- return nil
-}
-
// Attach implements LinkEndpoint.Attach.
func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 8e0ee1cd7..1c2afd554 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -148,10 +148,6 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- log.Fatal(err)
- }
-
subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
if err != nil {
log.Fatal(err)
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
new file mode 100644
index 000000000..e1b0d6354
--- /dev/null
+++ b/pkg/tcpip/socketops.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "sync/atomic"
+)
+
+// SocketOptions contains all the variables which store values for SOL_SOCKET
+// level options.
+//
+// +stateify savable
+type SocketOptions struct {
+ // These fields are accessed and modified using atomic operations.
+
+ // broadcastEnabled determines whether datagram sockets are allowed to send
+ // packets to a broadcast address.
+ broadcastEnabled uint32
+
+ // passCredEnabled determines whether SCM_CREDENTIALS socket control messages
+ // are enabled.
+ passCredEnabled uint32
+}
+
+func storeAtomicBool(addr *uint32, v bool) {
+ var val uint32
+ if v {
+ val = 1
+ }
+ atomic.StoreUint32(addr, val)
+}
+
+// GetBroadcast gets value for SO_BROADCAST option.
+func (so *SocketOptions) GetBroadcast() bool {
+ return atomic.LoadUint32(&so.broadcastEnabled) != 0
+}
+
+// SetBroadcast sets value for SO_BROADCAST option.
+func (so *SocketOptions) SetBroadcast(v bool) {
+ storeAtomicBool(&so.broadcastEnabled, v)
+}
+
+// GetPassCred gets value for SO_PASSCRED option.
+func (so *SocketOptions) GetPassCred() bool {
+ return atomic.LoadUint32(&so.passCredEnabled) != 0
+}
+
+// SetPassCred sets value for SO_PASSCRED option.
+func (so *SocketOptions) SetPassCred(v bool) {
+ storeAtomicBool(&so.passCredEnabled, v)
+}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 7a501acdc..cb7dec1ea 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -74,8 +74,30 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
}
func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
- // Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ netHdr := pkt.NetworkHeader().View()
+ _, dst := f.proto.ParseAddresses(netHdr)
+
+ addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint)
+ if addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ // Dispatch the packet to the transport protocol.
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt)
+ return
+ }
+
+ r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */)
+ if err != nil {
+ return
+ }
+ defer r.Release()
+
+ vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ pkt = NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: vv.ToView().ToVectorisedView(),
+ })
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ _ = r.WriteHeaderIncludedPacket(pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -106,8 +128,13 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf
panic("not implemented")
}
-func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
+ // The network header should not already be populated.
+ if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok {
+ return tcpip.ErrMalformedHeader
+ }
+
+ return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt)
}
func (f *fwdTestNetworkEndpoint) Close() {
@@ -117,6 +144,8 @@ func (f *fwdTestNetworkEndpoint) Close() {
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
type fwdTestNetworkProtocol struct {
+ stack *Stack
+
addrCache *linkAddrCache
neigh *neighborCache
addrResolveDelay time.Duration
@@ -304,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := fwdTestPacketInfo{
- Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}),
- }
-
- select {
- case e.C <- p:
- default:
- }
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
@@ -334,7 +349,10 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
- NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }},
+ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol {
+ proto.stack = s
+ return proto
+ }},
UseNeighborCache: useNeighborCache,
})
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 60c81a3aa..3e6ceff28 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -232,7 +232,8 @@ func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Unlock()
}
-func (n *NIC) isPromiscuousMode() bool {
+// Promiscuous implements NetworkInterface.
+func (n *NIC) Promiscuous() bool {
n.mu.RLock()
rv := n.mu.promiscuous
n.mu.RUnlock()
@@ -320,16 +321,21 @@ func (n *NIC) setSpoofing(enable bool) {
// primaryAddress returns an address that can be used to communicate with
// remoteAddr.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
- n.mu.RLock()
- spoofing := n.mu.spoofing
- n.mu.RUnlock()
-
ep, ok := n.networkEndpoints[protocol]
if !ok {
return nil
}
- return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return nil
+ }
+
+ n.mu.RLock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
+
+ return addressableEndpoint.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
}
type getAddressBehaviour int
@@ -388,11 +394,17 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre
// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean
// is passed to indicate whether or not we should generate temporary endpoints.
func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
- if ep, ok := n.networkEndpoints[protocol]; ok {
- return ep.AcquireAssignedAddress(address, createTemp, peb)
+ ep, ok := n.networkEndpoints[protocol]
+ if !ok {
+ return nil
}
- return nil
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return nil
+ }
+
+ return addressableEndpoint.AcquireAssignedAddress(address, createTemp, peb)
}
// addAddress adds a new address to n, so that it starts accepting packets
@@ -403,7 +415,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return tcpip.ErrUnknownProtocol
}
- addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
@@ -416,7 +433,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for p, ep := range n.networkEndpoints {
- for _, a := range ep.PermanentAddresses() {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ for _, a := range addressableEndpoint.PermanentAddresses() {
addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
@@ -427,7 +449,12 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for p, ep := range n.networkEndpoints {
- for _, a := range ep.PrimaryAddresses() {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ for _, a := range addressableEndpoint.PrimaryAddresses() {
addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
@@ -445,13 +472,23 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
return tcpip.AddressWithPrefix{}
}
- return ep.MainAddress()
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ return addressableEndpoint.MainAddress()
}
// removeAddress removes an address from n.
func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
for _, ep := range n.networkEndpoints {
- if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
continue
} else {
return err
@@ -564,13 +601,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
return false
}
-func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
- defer r.Release()
- r.PopulatePacketInfo(pkt)
- n.getNetworkEndpoint(protocol).HandlePacket(pkt)
-}
-
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the link endpoint.
@@ -592,7 +622,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.stats.Rx.Packets.Increment()
n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
- netProto, ok := n.stack.networkProtocols[protocol]
+ networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
n.mu.RUnlock()
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -617,11 +647,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
ep.HandlePacket(n.id, local, protocol, p)
}
- if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
- n.stack.stats.IP.PacketsReceived.Increment()
- }
-
// Parse headers.
+ netProto := n.stack.NetworkProtocolInstance(protocol)
transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
if !ok {
// The packet is too small to contain a network header.
@@ -636,9 +663,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
- src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
-
if n.stack.handleLocal && !n.IsLoopback() {
+ src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View())
if r := n.getAddress(protocol, src); r != nil {
r.DecRef()
@@ -651,78 +677,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
- // Loopback traffic skips the prerouting chain.
- if !n.IsLoopback() {
- // iptables filtering.
- ipt := n.stack.IPTables()
- address := n.primaryAddress(protocol)
- if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
- // iptables is telling us to drop the packet.
- n.stack.stats.IP.IPTablesPreroutingDropped.Increment()
- return
- }
- }
-
- if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil {
- n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt)
- return
- }
-
- // This NIC doesn't care about the packet. Find a NIC that cares about the
- // packet and forward it to the NIC.
- //
- // TODO: Should we be forwarding the packet even if promiscuous?
- if n.stack.Forwarding(protocol) {
- r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
- if err != nil {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- return
- }
-
- // Found a NIC.
- n := r.localAddressNIC
- if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
- if n.isValidForOutgoing(addressEndpoint) {
- pkt.NICID = n.ID()
- r.RemoteAddress = src
- pkt.NetworkPacketInfo = r.networkPacketInfo()
- n.getNetworkEndpoint(protocol).HandlePacket(pkt)
- addressEndpoint.DecRef()
- r.Release()
- return
- }
-
- addressEndpoint.DecRef()
- }
-
- // n doesn't have a destination endpoint.
- // Send the packet out of n.
- // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease
- // the TTL field for ipv4/ipv6.
-
- // pkt may have set its header and may not have enough headroom for
- // link-layer header for the other link to prepend. Here we create a new
- // packet to forward.
- fwdPkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
- // We need to do a deep copy of the IP packet because WritePacket (and
- // friends) take ownership of the packet buffer, but we do not own it.
- Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
- })
-
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
- if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- }
-
- r.Release()
- return
- }
-
- // If a packet socket handled the packet, don't treat it as invalid.
- if len(packetEPs) == 0 {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- }
+ networkEndpoint.HandlePacket(pkt)
}
// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 00e9a82ae..43ca03ada 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -65,10 +65,6 @@ const (
// NetworkPacketInfo holds information about a network layer packet.
type NetworkPacketInfo struct {
- // RemoteAddressBroadcast is true if the packet's remote address is a
- // broadcast address.
- RemoteAddressBroadcast bool
-
// LocalAddressBroadcast is true if the packet's local address is a broadcast
// address.
LocalAddressBroadcast bool
@@ -266,10 +262,10 @@ const (
// NetOptions is an interface that allows us to pass network protocol specific
// options through the Stack layer code.
type NetOptions interface {
- // AllocationSize returns the amount of memory that must be allocated to
+ // SizeWithPadding returns the amount of memory that must be allocated to
// hold the options given that the value must be rounded up to the next
// multiple of 4 bytes.
- AllocationSize() int
+ SizeWithPadding() int
}
// NetworkHeaderParams are the header parameters given as input by the
@@ -518,6 +514,9 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
+ // Promiscuous returns true if the interface is in promiscuous mode.
+ Promiscuous() bool
+
// WritePacketToRemote writes the packet to the given remote link address.
WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
}
@@ -525,8 +524,6 @@ type NetworkInterface interface {
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
- AddressableEndpoint
-
// Enable enables the endpoint.
//
// Must only be called when the stack is in a state that allows the endpoint
@@ -742,10 +739,6 @@ type LinkEndpoint interface {
// endpoint.
Capabilities() LinkEndpointCapabilities
- // WriteRawPacket writes a packet directly to the link. The packet
- // should already have an ethernet header. It takes ownership of vv.
- WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
-
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 15ff437c7..53cb6694f 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -170,28 +170,6 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
-// PopulatePacketInfo populates a packet buffer's packet information fields.
-//
-// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
-// the network layer.
-func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) {
- if r.local() {
- pkt.RXTransportChecksumValidated = true
- }
- pkt.NetworkPacketInfo = r.networkPacketInfo()
-}
-
-// networkPacketInfo returns the network packet information of the route.
-//
-// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
-// the network layer.
-func (r *Route) networkPacketInfo() NetworkPacketInfo {
- return NetworkPacketInfo{
- RemoteAddressBroadcast: r.IsOutboundBroadcast(),
- LocalAddressBroadcast: r.isInboundBroadcast(),
- }
-}
-
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
return r.outgoingNIC.ID()
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0fe157128..f4504e633 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -82,6 +82,7 @@ type TCPRACKState struct {
FACK seqnum.Value
RTT time.Duration
Reord bool
+ DSACKSeen bool
}
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
@@ -1080,7 +1081,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
Running: nic.Enabled(),
- Promiscuous: nic.isPromiscuousMode(),
+ Promiscuous: nic.Promiscuous(),
Loopback: nic.IsLoopback(),
}
nics[id] = NICInfo{
@@ -1809,49 +1810,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip
nic.unregisterPacketEndpoint(netProto, ep)
}
-// WritePacket writes data directly to the specified NIC. It adds an ethernet
-// header based on the arguments.
-func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+// WritePacketToRemote writes a payload on the specified NIC using the provided
+// network protocol and remote link address.
+func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
s.mu.Lock()
nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
return tcpip.ErrUnknownDevice
}
-
- // Add our own fake ethernet header.
- ethFields := header.EthernetFields{
- SrcAddr: nic.LinkEndpoint.LinkAddress(),
- DstAddr: dst,
- Type: netProto,
- }
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- vv := buffer.View(fakeHeader).ToVectorisedView()
- vv.Append(payload)
-
- if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
- return err
- }
-
- return nil
-}
-
-// WriteRawPacket writes data directly to the specified NIC without adding any
-// headers.
-func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
- s.mu.Lock()
- nic, ok := s.nics[nicID]
- s.mu.Unlock()
- if !ok {
- return tcpip.ErrUnknownDevice
- }
-
- if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
- return err
- }
-
- return nil
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()),
+ Data: payload,
+ })
+ return nic.WritePacketToRemote(remote, nil, netProto, pkt)
}
// NetworkProtocolInstance returns the protocol instance in the stack for the
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index dedfdd435..0d94af139 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -112,7 +112,15 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
netHdr := pkt.NetworkHeader().View()
- f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++
+
+ dst := tcpip.Address(netHdr[dstAddrOffset:][:1])
+ addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ return
+ }
+ addressEndpoint.DecRef()
+
+ f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
@@ -159,9 +167,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- pkt := pkt.Clone()
- r.PopulatePacketInfo(pkt)
- f.HandlePacket(pkt)
+ f.HandlePacket(pkt.Clone())
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -2214,88 +2220,6 @@ func TestNICStats(t *testing.T) {
}
}
-func TestNICForwarding(t *testing.T) {
- const nicID1 = 1
- const nicID2 = 2
- const dstAddr = tcpip.Address("\x03")
-
- tests := []struct {
- name string
- headerLen uint16
- }{
- {
- name: "Zero header length",
- },
- {
- name: "Non-zero header length",
- headerLen: 16,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- s.SetForwarding(fakeNetNumber, true)
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
- }
-
- ep2 := channelLinkWithHeaderLength{
- Endpoint: channel.New(10, defaultMTU, ""),
- headerLength: test.headerLen,
- }
- if err := s.CreateNIC(nicID2, &ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
- if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
- }
-
- // Route all packets to dstAddr to NIC 2.
- {
- subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
- }
-
- // Send a packet to dstAddr.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- pkt, ok := ep2.Read()
- if !ok {
- t.Fatal("packet not forwarded")
- }
-
- // Test that the link's MaxHeaderLength is honoured.
- if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want {
- t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want)
- }
-
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
// TestNICContextPreservation tests that you can read out via stack.NICInfo the
// Context data you pass via NICContext.Context in stack.CreateNICWithOptions.
func TestNICContextPreservation(t *testing.T) {
@@ -4228,3 +4152,63 @@ func TestFindRouteWithForwarding(t *testing.T) {
})
}
}
+
+func TestWritePacketToRemote(t *testing.T) {
+ const nicID = 1
+ const MTU = 1280
+ e := channel.New(1, MTU, linkAddr1)
+ s := stack.New(stack.Options{})
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("CreateNIC(%d) = %s", nicID, err)
+ }
+ tests := []struct {
+ name string
+ protocol tcpip.NetworkProtocolNumber
+ payload []byte
+ }{
+ {
+ name: "SuccessIPv4",
+ protocol: header.IPv4ProtocolNumber,
+ payload: []byte{1, 2, 3, 4},
+ },
+ {
+ name: "SuccessIPv6",
+ protocol: header.IPv6ProtocolNumber,
+ payload: []byte{5, 6, 7, 8},
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err)
+ }
+
+ pkt, ok := e.Read()
+ if got, want := ok, true; got != want {
+ t.Fatalf("e.Read() = %t, want %t", got, want)
+ }
+ if got, want := pkt.Proto, test.protocol; got != want {
+ t.Fatalf("pkt.Proto = %d, want %d", got, want)
+ }
+ if got, want := pkt.Route.RemoteLinkAddress, linkAddr2; got != want {
+ t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want)
+ }
+ if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
+ t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+
+ t.Run("InvalidNICID", func(t *testing.T) {
+ if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want)
+ }
+ pkt, ok := e.Read()
+ if got, want := ok, false; got != want {
+ t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want)
+ }
+ })
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index c457b67a2..5b9043d85 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
@@ -47,6 +46,9 @@ type fakeTransportEndpoint struct {
// acceptQueue is non-nil iff bound.
acceptQueue []fakeTransportEndpoint
+
+ // ops is used to set and get socket options.
+ ops tcpip.SocketOptions
}
func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
@@ -59,6 +61,9 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
+ return &f.ops
+}
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
@@ -184,9 +189,9 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
- a := f.acceptQueue[0]
+ a := &f.acceptQueue[0]
f.acceptQueue = f.acceptQueue[1:]
- return &a, nil, nil
+ return a, nil, nil
}
func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
@@ -553,87 +558,3 @@ func TestTransportOptions(t *testing.T) {
t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true")
}
}
-
-func TestTransportForwarding(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
- })
- s.SetForwarding(fakeNetNumber, true)
-
- // TODO(b/123449044): Change this to a channel NIC.
- ep1 := loopback.New()
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatalf("CreateNIC #1 failed: %v", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress #1 failed: %v", err)
- }
-
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatalf("CreateNIC #2 failed: %v", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress #2 failed: %v", err)
- }
-
- // Route all packets to address 3 to NIC 2 and all packets to address
- // 1 to NIC 1.
- {
- subnet0, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet0, Gateway: "\x00", NIC: 2},
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- })
- }
-
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Send a packet to address 1 from address 3.
- req := buffer.NewView(30)
- req[0] = 1
- req[1] = 3
- req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: req.ToVectorisedView(),
- }))
-
- aep, _, err := ep.Accept(nil)
- if err != nil || aep == nil {
- t.Fatalf("Accept failed: %v, %v", aep, err)
- }
-
- resp := buffer.NewView(30)
- if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- p, ok := ep2.Read()
- if !ok {
- t.Fatal("Response packet not forwarded")
- }
-
- nh := stack.PayloadSince(p.Pkt.NetworkHeader())
- if dst := nh[0]; dst != 3 {
- t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
- }
- if src := nh[1]; src != 1 {
- t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
- }
-}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 3ab2b7654..09361360f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -634,6 +634,10 @@ type Endpoint interface {
// LastError clears and returns the last error reported by the endpoint.
LastError() *Error
+
+ // SocketOptions returns the structure which contains all the socket
+ // level options.
+ SocketOptions() *SocketOptions
}
// LinkPacketInfo holds Link layer information for a received packet.
@@ -694,15 +698,10 @@ type WriteOptions struct {
type SockOptBool int
const (
- // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether datagram sockets are allowed to send packets to a broadcast
- // address.
- BroadcastOption SockOptBool = iota
-
// CorkOption is used by SetSockOptBool/GetSockOptBool to specify if
// data should be held until segments are full by the TCP transport
// protocol.
- CorkOption
+ CorkOption SockOptBool = iota
// DelayOption is used by SetSockOptBool/GetSockOptBool to specify if
// data should be sent out immediately by the transport protocol. For
@@ -722,12 +721,6 @@ const (
// whether UDP checksum is disabled for this socket.
NoChecksumOption
- // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether SCM_CREDENTIALS socket control messages are enabled.
- //
- // Only supported on Unix sockets.
- PasscredOption
-
// QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
QuickAckOption
@@ -1464,9 +1457,13 @@ type ICMPStats struct {
// IPStats collects IP-specific stats (both v4 and v6).
type IPStats struct {
// PacketsReceived is the total number of IP packets received from the
- // link layer in nic.DeliverNetworkPacket.
+ // link layer.
PacketsReceived *StatCounter
+ // DisabledPacketsReceived is the total number of IP packets received from the
+ // link layer when the IP layer is disabled.
+ DisabledPacketsReceived *StatCounter
+
// InvalidDestinationAddressesReceived is the total number of IP packets
// received with an unknown or invalid destination address.
InvalidDestinationAddressesReceived *StatCounter
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index bf7594268..39343b966 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -229,19 +229,6 @@ func TestForwarding(t *testing.T) {
t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
}
- if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
- if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
- if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
- if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
-
if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index fe7c1bb3d..bf8a1241f 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -140,13 +140,6 @@ func TestPing(t *testing.T) {
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
- if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
- if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
-
if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil {
t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err)
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 1eecd7957..9d30329f5 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -514,9 +514,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
}
- if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
- t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err)
- }
+ ep.SocketOptions().SetBroadcast(true)
bindAddr := tcpip.FullAddress{Port: localPort}
if bindWildcard {
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 763cd8f84..fe6514bcd 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -79,6 +79,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
@@ -853,3 +856,8 @@ func (*endpoint) Wait() {}
func (*endpoint) LastError() *tcpip.Error {
return nil
}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 31831a6d8..3bff3755a 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -89,6 +89,9 @@ type endpoint struct {
// lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// NewEndpoint returns a new packet endpoint.
@@ -549,3 +552,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats {
}
func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
+
+func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &ep.ops
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 7b6a87ba9..0a1e1fbb3 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -89,6 +89,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -753,6 +756,12 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements stack.TransportEndpoint.Wait.
func (*endpoint) Wait() {}
+// LastError implements tcpip.Endpoint.LastError.
func (*endpoint) LastError() *tcpip.Error {
return nil
}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 47982ca41..6e5adc383 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -235,11 +235,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
return n, nil
}
-// createEndpointAndPerformHandshake creates a new endpoint in connected state
-// and then performs the TCP 3-way handshake.
+// startHandshake creates a new endpoint in connecting state and then sends
+// the SYN-ACK for the TCP 3-way handshake. It returns the state of the
+// handshake in progress, which includes the new endpoint in the SYN-RCVD
+// state.
//
-// The new endpoint is returned with e.mu held.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+// On success, a handshake h is returned with h.ep.mu held.
+//
+// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
+func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
@@ -257,10 +261,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// listenEP is nil when listenContext is used by tcp.Forwarder.
deferAccept := time.Duration(0)
if l.listenEP != nil {
- l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
- l.listenEP.mu.Unlock()
// Ensure we release any registrations done by the newly
// created endpoint.
ep.mu.Unlock()
@@ -278,16 +280,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- l.listenEP.mu.Unlock()
- }
+ l.removePendingEndpoint(ep)
return nil, tcpip.ErrConnectionAborted
}
deferAccept = l.listenEP.deferAccept
- l.listenEP.mu.Unlock()
}
// Register new endpoint so that packets are routed to it.
@@ -306,28 +304,33 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
ep.isRegistered = true
- // Perform the 3-way handshake.
- h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
- if err := h.execute(); err != nil {
- ep.mu.Unlock()
- ep.Close()
- ep.notifyAborted()
-
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
- ep.drainClosingSegmentQueue()
-
+ // Initialize and start the handshake.
+ h := ep.newPassiveHandshake(isn, irs, opts, deferAccept)
+ if err := h.start(); err != nil {
+ l.cleanupFailedHandshake(h)
return nil, err
}
- ep.isConnectNotified = true
+ return h, nil
+}
- // Update the receive window scaling. We can't do it before the
- // handshake because it's possible that the peer doesn't support window
- // scaling.
- ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
+// performHandshake performs a TCP 3-way handshake. On success, the new
+// established endpoint is returned with e.mu held.
+//
+// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
+func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+ h, err := l.startHandshake(s, opts, queue, owner)
+ if err != nil {
+ return nil, err
+ }
+ ep := h.ep
+ if err := h.complete(); err != nil {
+ ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ ep.stats.FailedConnectionAttempts.Increment()
+ l.cleanupFailedHandshake(h)
+ return nil, err
+ }
+ l.cleanupCompletedHandshake(h)
return ep, nil
}
@@ -354,6 +357,39 @@ func (l *listenContext) closeAllPendingEndpoints() {
l.pending.Wait()
}
+// Precondition: h.ep.mu must be held.
+func (l *listenContext) cleanupFailedHandshake(h *handshake) {
+ e := h.ep
+ e.mu.Unlock()
+ e.Close()
+ e.notifyAborted()
+ if l.listenEP != nil {
+ l.removePendingEndpoint(e)
+ }
+ e.drainClosingSegmentQueue()
+ e.h = nil
+}
+
+// cleanupCompletedHandshake transfers any state from the completed handshake to
+// the new endpoint.
+//
+// Precondition: h.ep.mu must be held.
+func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
+ e := h.ep
+ if l.listenEP != nil {
+ l.removePendingEndpoint(e)
+ }
+ e.isConnectNotified = true
+
+ // Update the receive window scaling. We can't do it before the
+ // handshake because it's possible that the peer doesn't support window
+ // scaling.
+ e.rcv.rcvWndScale = e.h.effectiveRcvWndScale()
+
+ // Clean up handshake state stored in the endpoint so that it can be GCed.
+ e.h = nil
+}
+
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state (acceptedChan is nil),
// the new endpoint is closed instead.
@@ -433,23 +469,40 @@ func (e *endpoint) notifyAborted() {
//
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
- defer ctx.synRcvdCount.dec()
+//
+// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error {
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
+ h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- e.decSynRcvdCount()
- return
+ e.synRcvdCount--
+ return err
}
- ctx.removePendingEndpoint(n)
- e.decSynRcvdCount()
- n.startAcceptedLoop()
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(n)
+ go func() {
+ defer ctx.synRcvdCount.dec()
+ if err := h.complete(); err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ ctx.cleanupFailedHandshake(h)
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+ return
+ }
+ ctx.cleanupCompletedHandshake(h)
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+ h.ep.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+ e.deliverAccepted(h.ep)
+ }() // S/R-SAFE: synRcvdCount is the barrier.
+
+ return nil
}
func (e *endpoint) incSynRcvdCount() bool {
@@ -462,12 +515,6 @@ func (e *endpoint) incSynRcvdCount() bool {
return canInc
}
-func (e *endpoint) decSynRcvdCount() {
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
-}
-
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
@@ -477,6 +524,8 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
+//
+// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error {
e.rcvListMu.Lock()
rcvClosed := e.rcvClosed
@@ -500,7 +549,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
// backlog.
if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
s.incRef()
- go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
+ _ = e.handleSynSegment(ctx, s, &opts)
return nil
}
ctx.synRcvdCount.dec()
@@ -712,7 +761,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
// to the endpoint.
e.setEndpointState(StateClose)
- // close any endpoints in SYN-RCVD state.
+ // Close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
// Do cleanup if needed.
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 6e9015be1..6661e8915 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -102,21 +102,26 @@ type handshake struct {
// been received. This is required to stop retransmitting the
// original SYN-ACK when deferAccept is enabled.
acked bool
+
+ // sendSYNOpts is the cached values for the SYN options to be sent.
+ sendSYNOpts header.TCPSynOptions
}
-func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
- h := handshake{
- ep: ep,
+func (e *endpoint) newHandshake() *handshake {
+ h := &handshake{
+ ep: e,
active: true,
- rcvWnd: rcvWnd,
- rcvWndScale: ep.rcvWndScaleForHandshake(),
+ rcvWnd: seqnum.Size(e.initialReceiveWindow()),
+ rcvWndScale: e.rcvWndScaleForHandshake(),
}
h.resetState()
+ // Store reference to handshake state in endpoint.
+ e.h = h
return h
}
-func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
- h := newHandshake(ep, rcvWnd)
+func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake {
+ h := e.newHandshake()
h.resetToSynRcvd(isn, irs, opts, deferAccept)
return h
}
@@ -491,7 +496,7 @@ func (h *handshake) resolveRoute() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.LastError()
+ return h.ep.lastErrorLocked()
}
}
@@ -502,8 +507,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
}
-// execute executes the TCP 3-way handshake.
-func (h *handshake) execute() *tcpip.Error {
+// start resolves the route if necessary and sends the first
+// SYN/SYN-ACK.
+func (h *handshake) start() *tcpip.Error {
if h.ep.route.IsResolutionRequired() {
if err := h.resolveRoute(); err != nil {
return err
@@ -511,19 +517,7 @@ func (h *handshake) execute() *tcpip.Error {
}
h.startTime = time.Now()
- // Initialize the resend timer.
- resendWaker := sleep.Waker{}
- timeOut := time.Duration(time.Second)
- rt := time.AfterFunc(timeOut, resendWaker.Assert)
- defer rt.Stop()
-
- // Set up the wakers.
- s := sleep.Sleeper{}
- s.AddWaker(&resendWaker, wakerForResend)
- s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
- s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
- defer s.Done()
-
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
var sackEnabled tcpip.TCPSACKEnabled
if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
// If stack returned an error when checking for SACKEnabled
@@ -531,10 +525,6 @@ func (h *handshake) execute() *tcpip.Error {
sackEnabled = false
}
- // Send the initial SYN segment and loop until the handshake is
- // completed.
- h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
-
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
@@ -544,9 +534,8 @@ func (h *handshake) execute() *tcpip.Error {
MSS: h.ep.amss,
}
- // Execute is also called in a listen context so we want to make sure we
- // only send the TS/SACK option when we received the TS/SACK in the
- // initial SYN.
+ // start() is also called in a listen context so we want to make sure we only
+ // send the TS/SACK option when we received the TS/SACK in the initial SYN.
if h.state == handshakeSynRcvd {
synOpts.TS = h.ep.sendTSOk
synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
@@ -557,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error {
}
}
+ h.sendSYNOpts = synOpts
h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
@@ -566,7 +556,25 @@ func (h *handshake) execute() *tcpip.Error {
ack: h.ackNum,
rcvWnd: h.rcvWnd,
}, synOpts)
+ return nil
+}
+
+// complete completes the TCP 3-way handshake initiated by h.start().
+func (h *handshake) complete() *tcpip.Error {
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ resendWaker := sleep.Waker{}
+ s.AddWaker(&resendWaker, wakerForResend)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+ defer s.Done()
+ // Initialize the resend timer.
+ timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert)
+ if err != nil {
+ return err
+ }
+ defer timer.stop()
for h.state != handshakeCompleted {
// Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
// throughout handshake processing).
@@ -576,11 +584,9 @@ func (h *handshake) execute() *tcpip.Error {
switch index {
case wakerForResend:
- timeOut *= 2
- if timeOut > MaxRTO {
- return tcpip.ErrTimeout
+ if err := timer.reset(); err != nil {
+ return err
}
- rt.Reset(timeOut)
// Resend the SYN/SYN-ACK only if the following conditions hold.
// - It's an active handshake (deferAccept does not apply)
// - It's a passive handshake and we have not yet got the final-ACK.
@@ -598,7 +604,7 @@ func (h *handshake) execute() *tcpip.Error {
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
- }, synOpts)
+ }, h.sendSYNOpts)
}
case wakerForNotification:
@@ -624,9 +630,8 @@ func (h *handshake) execute() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.LastError()
+ return h.ep.lastErrorLocked()
}
-
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
return err
@@ -637,6 +642,34 @@ func (h *handshake) execute() *tcpip.Error {
return nil
}
+type backoffTimer struct {
+ timeout time.Duration
+ maxTimeout time.Duration
+ t *time.Timer
+}
+
+func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) {
+ if timeout > maxTimeout {
+ return nil, tcpip.ErrTimeout
+ }
+ bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout}
+ bt.t = time.AfterFunc(timeout, f)
+ return bt, nil
+}
+
+func (bt *backoffTimer) reset() *tcpip.Error {
+ bt.timeout *= 2
+ if bt.timeout > MaxRTO {
+ return tcpip.ErrTimeout
+ }
+ bt.t.Reset(bt.timeout)
+ return nil
+}
+
+func (bt *backoffTimer) stop() {
+ bt.t.Stop()
+}
+
func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
if synOpts.TS {
@@ -967,7 +1000,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
// one used by Linux. We need to handle the case of window being shrunk
@@ -1106,7 +1139,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// delete the TCB, and return.
case StateCloseWait:
e.transitionToStateCloseLocked()
- e.HardError = tcpip.ErrAborted
+ e.hardError = tcpip.ErrAborted
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
@@ -1318,7 +1351,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
epilogue := func() {
// e.mu is expected to be hold upon entering this section.
-
if e.snd != nil {
e.snd.resendTimer.cleanup()
}
@@ -1342,20 +1374,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if handshake {
- // This is an active connection, so we must initiate the 3-way
- // handshake, and then inform potential waiters about its
- // completion.
- initialRcvWnd := e.initialReceiveWindow()
- h := newHandshake(e, seqnum.Size(initialRcvWnd))
- h.ep.setEndpointState(StateSynSent)
-
- if err := h.execute(); err != nil {
+ if err := e.h.complete(); err != nil {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
e.workerCleanup = true
// Lock released below.
@@ -1364,9 +1389,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
}
- e.keepalive.timer.init(&e.keepalive.waker)
- defer e.keepalive.timer.cleanup()
-
drained := e.drainDone != nil
if drained {
close(e.drainDone)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 23b9de8c5..36b915510 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -315,11 +315,6 @@ func (*Stats) IsEndpointStats() {}
// +stateify savable
type EndpointInfo struct {
stack.TransportEndpointInfo
-
- // HardError is meaningful only when state is stateError. It stores the
- // error to be returned when read/write syscalls are called and the
- // endpoint is in this state. HardError is protected by endpoint mu.
- HardError *tcpip.Error `state:".(string)"`
}
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
@@ -386,6 +381,11 @@ type endpoint struct {
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
+ // hardError is meaningful only when state is stateError. It stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. hardError is protected by endpoint mu.
+ hardError *tcpip.Error `state:".(string)"`
+
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
lastErrorMu sync.Mutex `state:"nosave"`
@@ -440,9 +440,11 @@ type endpoint struct {
ttl uint8
v6only bool
isConnectNotified bool
- // TCP should never broadcast but Linux nevertheless supports enabling/
- // disabling SO_BROADCAST, albeit as a NOOP.
- broadcast bool
+
+ // h stores a reference to the current handshake state if the endpoint is in
+ // the SYN-SENT or SYN-RECV states, in which case endpoint == endpoint.h.ep.
+ // nil otherwise.
+ h *handshake `state:"nosave"`
// portFlags stores the current values of port related flags.
portFlags ports.Flags
@@ -685,6 +687,9 @@ type endpoint struct {
// linger is used for SO_LINGER socket option.
linger tcpip.LingerOption
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -922,6 +927,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.segmentQueue.ep = e
e.tsOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
+ e.keepalive.timer.init(&e.keepalive.waker)
return e
}
@@ -1146,6 +1152,7 @@ func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
e.closePendingAcceptableConnectionsLocked()
+ e.keepalive.timer.cleanup()
e.workerCleanup = false
@@ -1272,11 +1279,20 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvListMu.Unlock()
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-func (e *endpoint) LastError() *tcpip.Error {
+// Preconditions: e.mu must be held to call this function.
+func (e *endpoint) hardErrorLocked() *tcpip.Error {
+ err := e.hardError
+ e.hardError = nil
+ return err
+}
+
+// Preconditions: e.mu must be held to call this function.
+func (e *endpoint) lastErrorLocked() *tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
err := e.lastError
@@ -1284,6 +1300,16 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
+// LastError implements tcpip.Endpoint.LastError.
+func (e *endpoint) LastError() *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+ if err := e.hardErrorLocked(); err != nil {
+ return err
+ }
+ return e.lastErrorLocked()
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1305,9 +1331,11 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
- he := e.HardError
if s == StateError {
- return buffer.View{}, tcpip.ControlMessages{}, he
+ if err := e.hardErrorLocked(); err != nil {
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
e.stats.ReadErrors.NotConnected.Increment()
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
@@ -1363,9 +1391,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
+ // The endpoint cannot be written to if it's not connected.
switch s := e.EndpointState(); {
case s == StateError:
- return 0, e.HardError
+ if err := e.hardErrorLocked(); err != nil {
+ return 0, err
+ }
+ return 0, tcpip.ErrClosedForSend
case !s.connecting() && !s.connected():
return 0, tcpip.ErrClosedForSend
case s.connecting():
@@ -1479,7 +1511,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// but has some pending unread data.
if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
- return 0, tcpip.ControlMessages{}, e.HardError
+ return 0, tcpip.ControlMessages{}, e.hardErrorLocked()
}
e.stats.ReadErrors.InvalidEndpointState.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
@@ -1599,11 +1631,6 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
- case tcpip.BroadcastOption:
- e.LockUser()
- e.broadcast = v
- e.UnlockUser()
-
case tcpip.CorkOption:
e.LockUser()
if !v {
@@ -1950,11 +1977,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.BroadcastOption:
- e.LockUser()
- v := e.broadcast
- e.UnlockUser()
- return v, nil
case tcpip.CorkOption:
return atomic.LoadUint32(&e.cork) != 0, nil
@@ -2185,6 +2207,8 @@ func (*endpoint) Disconnect() *tcpip.Error {
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
err := e.connect(addr, true, true)
if err != nil && !err.IgnoreStats() {
+ // Connect failed. Let's wake up any waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
}
@@ -2244,7 +2268,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return tcpip.ErrAlreadyConnecting
case StateError:
- return e.HardError
+ if err := e.hardErrorLocked(); err != nil {
+ return err
+ }
+ return tcpip.ErrConnectionAborted
default:
return tcpip.ErrInvalidEndpointState
@@ -2397,14 +2424,70 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
if run {
- e.workerRunning = true
- e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
- go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+ if err := e.startMainLoop(handshake); err != nil {
+ return err
+ }
}
return tcpip.ErrConnectStarted
}
+// startMainLoop sends the initial SYN and starts the main loop for the
+// endpoint.
+func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error {
+ preloop := func() *tcpip.Error {
+ if handshake {
+ h := e.newHandshake()
+ e.setEndpointState(StateSynSent)
+ if err := h.start(); err != nil {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ e.setEndpointState(StateError)
+ e.hardError = err
+
+ // Call cleanupLocked to free up any reservations.
+ e.cleanupLocked()
+ return err
+ }
+ }
+ e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+ return nil
+ }
+
+ if e.route.IsResolutionRequired() {
+ // If the endpoint is closed between releasing e.mu and the goroutine below
+ // acquiring it, make sure that cleanup is deferred to the new goroutine.
+ e.workerRunning = true
+
+ // Sending the initial SYN may block due to route resolution; do it in a
+ // separate goroutine to avoid blocking the syscall goroutine.
+ go func() { // S/R-SAFE: will be drained before save.
+ e.mu.Lock()
+ if err := preloop(); err != nil {
+ e.workerRunning = false
+ e.mu.Unlock()
+ return
+ }
+ e.mu.Unlock()
+ _ = e.protocolMainLoop(handshake, nil)
+ }()
+ return nil
+ }
+
+ // No route resolution is required, so we can send the initial SYN here without
+ // blocking. This will hopefully reduce overall latency by overlapping time
+ // spent waiting for a SYN-ACK and time spent spinning up a new goroutine
+ // for the main loop.
+ if err := preloop(); err != nil {
+ return err
+ }
+ e.workerRunning = true
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+ return nil
+}
+
// ConnectEndpoint is not supported.
func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
@@ -3061,6 +3144,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
FACK: rc.fack,
RTT: rc.rtt,
Reord: rc.reorderSeen,
+ DSACKSeen: rc.dsackSeen,
}
return s
}
@@ -3130,3 +3214,8 @@ func (e *endpoint) Wait() {
<-notifyCh
}
}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 2bcc5e1c2..ba67176b5 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -172,6 +172,7 @@ func (e *endpoint) afterLoad() {
// Condition variables and mutexs are not S/R'ed so reinitialize
// acceptCond with e.acceptMu.
e.acceptCond = sync.NewCond(&e.acceptMu)
+ e.keepalive.timer.init(&e.keepalive.waker)
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
@@ -320,21 +321,21 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) {
}
// saveHardError is invoked by stateify.
-func (e *EndpointInfo) saveHardError() string {
- if e.HardError == nil {
+func (e *endpoint) saveHardError() string {
+ if e.hardError == nil {
return ""
}
- return e.HardError.String()
+ return e.hardError.String()
}
// loadHardError is invoked by stateify.
-func (e *EndpointInfo) loadHardError(s string) {
+func (e *endpoint) loadHardError(s string) {
if s == "" {
return
}
- e.HardError = tcpip.StringToError(s)
+ e.hardError = tcpip.StringToError(s)
}
// saveMeasureTime is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 0664789da..596178625 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
}
f := r.forwarder
- ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
+ ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{
MSS: r.synOptions.MSS,
WS: r.synOptions.WS,
TS: r.synOptions.TS,
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index d312b1b8b..e0a50a919 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -29,12 +29,12 @@ import (
//
// +stateify savable
type rackControl struct {
+ // dsackSeen indicates if the connection has seen a DSACK.
+ dsackSeen bool
+
// endSequence is the ending TCP sequence number of rackControl.seg.
endSequence seqnum.Value
- // dsack indicates if the connection has seen a DSACK.
- dsack bool
-
// fack is the highest selectively or cumulatively acknowledged
// sequence.
fack seqnum.Value
@@ -122,3 +122,8 @@ func (rc *rackControl) detectReorder(seg *segment) {
rc.reorderSeen = true
}
}
+
+// setDSACKSeen updates rack control if duplicate SACK is seen by the connection.
+func (rc *rackControl) setDSACKSeen() {
+ rc.dsackSeen = true
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index ab5fa4fb7..0e0fdf14c 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -1285,25 +1285,29 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
func (s *sender) walkSACK(rcvdSeg *segment) {
- if len(rcvdSeg.parsedOptions.SACKBlocks) == 0 {
+ // Look for DSACK block.
+ idx := 0
+ n := len(rcvdSeg.parsedOptions.SACKBlocks)
+ if s.checkDSACK(rcvdSeg) {
+ s.rc.setDSACKSeen()
+ idx = 1
+ n--
+ }
+
+ if n == 0 {
return
}
// Sort the SACK blocks. The first block is the most recent unacked
// block. The following blocks can be in arbitrary order.
- sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks))
- copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks)
+ sackBlocks := make([]header.SACKBlock, n)
+ copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks[idx:])
sort.Slice(sackBlocks, func(i, j int) bool {
return sackBlocks[j].Start.LessThan(sackBlocks[i].Start)
})
seg := s.writeList.Front()
for _, sb := range sackBlocks {
- // This check excludes DSACK blocks.
- if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) {
- continue
- }
-
for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 {
if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked {
s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
@@ -1315,6 +1319,50 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
}
}
+// checkDSACK checks if a DSACK is reported and updates it in RACK.
+func (s *sender) checkDSACK(rcvdSeg *segment) bool {
+ n := len(rcvdSeg.parsedOptions.SACKBlocks)
+ if n == 0 {
+ return false
+ }
+
+ sb := rcvdSeg.parsedOptions.SACKBlocks[0]
+ // Check if SACK block is invalid.
+ if sb.End.LessThan(sb.Start) {
+ return false
+ }
+
+ // See: https://tools.ietf.org/html/rfc2883#section-5 DSACK is sent in
+ // at most one SACK block. DSACK is detected in the below two cases:
+ // * If the SACK sequence space is less than this cumulative ACK, it is
+ // an indication that the segment identified by the SACK block has
+ // been received more than once by the receiver.
+ // * If the sequence space in the first SACK block is greater than the
+ // cumulative ACK, then the sender next compares the sequence space
+ // in the first SACK block with the sequence space in the second SACK
+ // block, if there is one. This comparison can determine if the first
+ // SACK block is reporting duplicate data that lies above the
+ // cumulative ACK.
+ if sb.Start.LessThan(rcvdSeg.ackNumber) {
+ return true
+ }
+
+ if n > 1 {
+ sb1 := rcvdSeg.parsedOptions.SACKBlocks[1]
+ if sb1.End.LessThan(sb1.Start) {
+ return false
+ }
+
+ // If the first SACK block is fully covered by second SACK
+ // block, then the first block is a DSACK block.
+ if sb.End.LessThanEq(sb1.End) && sb1.Start.LessThanEq(sb.Start) {
+ return true
+ }
+ }
+
+ return false
+}
+
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index d3f92b48c..9818ffa0f 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -30,15 +30,17 @@ const (
maxPayload = 10
tsOptionSize = 12
maxTCPOptionSize = 40
+ mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
)
// TestRACKUpdate tests the RACK related fields are updated when an ACK is
// received on a SACK enabled connection.
func TestRACKUpdate(t *testing.T) {
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ c := context.New(t, uint32(mtu))
defer c.Cleanup()
var xmitTime time.Time
+ probeDone := make(chan struct{})
c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
// Validate that the endpoint Sender.RACKState is what we expect.
if state.Sender.RACKState.XmitTime.Before(xmitTime) {
@@ -54,6 +56,7 @@ func TestRACKUpdate(t *testing.T) {
if state.Sender.RACKState.RTT == 0 {
t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0")
}
+ close(probeDone)
})
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
@@ -73,18 +76,20 @@ func TestRACKUpdate(t *testing.T) {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
- time.Sleep(200 * time.Millisecond)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-probeDone
}
// TestRACKDetectReorder tests that RACK detects packet reordering.
func TestRACKDetectReorder(t *testing.T) {
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ c := context.New(t, uint32(mtu))
defer c.Cleanup()
- const ackNum = 2
-
var n int
- ch := make(chan struct{})
+ const ackNumToVerify = 2
+ probeDone := make(chan struct{})
c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
gotSeq := state.Sender.RACKState.FACK
wantSeq := state.Sender.SndNxt
@@ -95,7 +100,7 @@ func TestRACKDetectReorder(t *testing.T) {
}
n++
- if n < ackNum {
+ if n < ackNumToVerify {
if state.Sender.RACKState.Reord {
t.Fatalf("RACK reorder detected when there is no reordering")
}
@@ -105,11 +110,11 @@ func TestRACKDetectReorder(t *testing.T) {
if state.Sender.RACKState.Reord == false {
t.Fatalf("RACK reorder detection failed")
}
- close(ch)
+ close(probeDone)
})
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(ackNum * maxPayload)
+ data := buffer.NewView(ackNumToVerify * maxPayload)
for i := range data {
data[i] = byte(i)
}
@@ -120,7 +125,7 @@ func TestRACKDetectReorder(t *testing.T) {
}
bytesRead := 0
- for i := 0; i < ackNum; i++ {
+ for i := 0; i < ackNumToVerify; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
}
@@ -133,5 +138,393 @@ func TestRACKDetectReorder(t *testing.T) {
// Wait for the probe function to finish processing the ACK before the
// test completes.
- <-ch
+ <-probeDone
+}
+
+func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View {
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ data := buffer.NewView(numPackets * maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ for i := 0; i < numPackets; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+
+ return data
+}
+
+const (
+ validDSACKDetected = 1
+ failedToDetectDSACK = 2
+ invalidDSACKDetected = 3
+)
+
+func addDSACKSeenCheckerProbe(t *testing.T, c *context.Context, numACK int, probeDone chan int) {
+ var n int
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that RACK detects DSACK.
+ n++
+ if n < numACK {
+ if state.Sender.RACKState.DSACKSeen {
+ probeDone <- invalidDSACKDetected
+ }
+ return
+ }
+
+ if !state.Sender.RACKState.DSACKSeen {
+ probeDone <- failedToDetectDSACK
+ return
+ }
+ probeDone <- validDSACKDetected
+ })
+}
+
+// TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments.
+// See: https://tools.ietf.org/html/rfc2883#section-4.1.1.
+func TestRACKDetectDSACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 2
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 8
+ data := sendAndReceive(t, c, numPackets)
+
+ // Cumulative ACK for [1-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ bytesRead := 5 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Expect retransmission of #6 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // Send DSACK block for #6 packet indicating both
+ // initial and retransmitted packet are received and
+ // packets [1-7] are received.
+ start := c.IRS.Add(seqnum.Size(bytesRead))
+ end := start.Add(maxPayload)
+ bytesRead += 2 * maxPayload
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of
+// order segments.
+// See: https://tools.ietf.org/html/rfc2883#section-4.1.2.
+func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 2
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 10
+ data := sendAndReceive(t, c, numPackets)
+
+ // Cumulative ACK for [1-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ bytesRead := 5 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Expect retransmission of #6 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // Send DSACK block for #6 packet indicating both
+ // initial and retransmitted packet are received and
+ // packets [1-7] are received.
+ start := c.IRS.Add(seqnum.Size(bytesRead))
+ end := start.Add(maxPayload)
+ bytesRead += 2 * maxPayload
+ // Send DSACK block for #6 along with out of
+ // order #9 packet is received.
+ start1 := c.IRS.Add(seqnum.Size(bytesRead) + maxPayload)
+ end1 := start1.Add(maxPayload)
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}})
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a
+// duplicate of out of order packet.
+// See: https://tools.ietf.org/html/rfc2883#section-4.1.3
+func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 4
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 10
+ sendAndReceive(t, c, numPackets)
+
+ // ACK [1-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ bytesRead := 5 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Send SACK indicating #6 packet is missing and received #7 packet.
+ offset := seqnum.Size(bytesRead + maxPayload)
+ start := c.IRS.Add(1 + offset)
+ end := start.Add(maxPayload)
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Send SACK with #6 packet is missing and received [7-8] packets.
+ end = start.Add(2 * maxPayload)
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Consider #8 packet is duplicated on the network and send DSACK.
+ dsackStart := c.IRS.Add(1 + offset + maxPayload)
+ dsackEnd := dsackStart.Add(maxPayload)
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment.
+// See: https://tools.ietf.org/html/rfc2883#section-4.2.1.
+func TestRACKDetectDSACKSingleDup(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 4
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 4
+ data := sendAndReceive(t, c, numPackets)
+
+ // Send ACK for #1 packet.
+ bytesRead := maxPayload
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAck(seq, bytesRead)
+
+ // Missing [2-3] packets and received #4 packet.
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
+ end := start.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Expect retransmission of #2 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // ACK for retransmitted #2 packet.
+ bytesRead += maxPayload
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Simulate receving delayed subsegment of #2 packet and delayed #3 packet by
+ // sending DSACK block for the subsegment.
+ dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
+ c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}})
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous
+// duplicate subsegments covered by the cumulative acknowledgement.
+// See: https://tools.ietf.org/html/rfc2883#section-4.2.2.
+func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 5
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 6
+ data := sendAndReceive(t, c, numPackets)
+
+ // Send ACK for #1 packet.
+ bytesRead := maxPayload
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAck(seq, bytesRead)
+
+ // Missing [2-5] packets and received #6 packet.
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(1 + seqnum.Size(5*maxPayload))
+ end := start.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Expect retransmission of #2 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // Received delayed #2 packet.
+ bytesRead += maxPayload
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Received delayed #4 packet.
+ start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
+ end1 := start1.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
+
+ // Simulate receiving retransmitted subsegment for #2 packet and delayed #3
+ // packet by sending DSACK block for #2 packet.
+ dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
+ c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not
+// covered by the cumulative acknowledgement.
+// See: https://tools.ietf.org/html/rfc2883#section-4.2.3.
+func TestRACKDetectDSACKDup(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 5
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 7
+ data := sendAndReceive(t, c, numPackets)
+
+ // Send ACK for #1 packet.
+ bytesRead := maxPayload
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAck(seq, bytesRead)
+
+ // Missing [2-6] packets and SACK #7 packet.
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
+ end := start.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Received delayed #3 packet.
+ start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
+ end1 := start1.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
+
+ // Expect retransmission of #2 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // Consider #2 packet has been dropped and SACK #4 packet.
+ start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
+ end2 := start2.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}})
+
+ // Simulate receiving retransmitted subsegment for #3 packet and delayed #5
+ // packet by sending DSACK block for the subsegment.
+ dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
+ end1 = end1.Add(seqnum.Size(2 * maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}})
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK
+// is not the first SACK block.
+func TestRACKWithInvalidDSACKBlock(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ const ackNumToVerify = 2
+ var n int
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that RACK does not detect DSACK when DSACK block is
+ // not the first SACK block.
+ n++
+ t.Helper()
+ if state.Sender.RACKState.DSACKSeen {
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+
+ if n == ackNumToVerify {
+ close(probeDone)
+ }
+ })
+
+ numPackets := 10
+ data := sendAndReceive(t, c, numPackets)
+
+ // Cumulative ACK for [1-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ bytesRead := 5 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Expect retransmission of #6 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // Send DSACK block for #6 packet indicating both
+ // initial and retransmitted packet are received and
+ // packets [1-7] are received.
+ start := c.IRS.Add(seqnum.Size(bytesRead))
+ end := start.Add(maxPayload)
+ bytesRead += 2 * maxPayload
+
+ // Send DSACK block as second block.
+ start1 := c.IRS.Add(seqnum.Size(bytesRead) + maxPayload)
+ end1 := start1.Add(maxPayload)
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ <-probeDone
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index fcc3c5000..c366a4cbc 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -75,9 +75,6 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
- if err := ep.LastError(); err != tcpip.ErrAborted {
- t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted)
- }
// Call Connect again to retreive the handshake failure status
// and stats updates.
@@ -3198,6 +3195,11 @@ loop:
case tcpip.ErrWouldBlock:
select {
case <-ch:
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ }
+ break loop
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for reset to arrive")
}
@@ -3207,14 +3209,10 @@ loop:
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
- }
+
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
-
if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
}
@@ -5717,6 +5715,50 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
}
+func TestSYNRetransmit(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Start listening.
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send the same SYN packet multiple times. We should still get a valid SYN-ACK
+ // reply.
+ irs := seqnum.Value(789)
+ for i := 0; i < 5; i++ {
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Receive the SYN-ACK reply.
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.TCPAckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...))
+}
+
func TestSynRcvdBadSeqNumber(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index 7981d469b..38a335840 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -84,6 +84,10 @@ func (t *timer) init(w *sleep.Waker) {
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
+ if t.timer == nil {
+ // No cleanup needed.
+ return
+ }
t.timer.Stop()
*t = timer{}
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c78549424..7ebae63d8 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -56,6 +56,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 9bcb918bb..81601f559 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -108,7 +108,6 @@ type endpoint struct {
multicastLoop bool
portFlags ports.Flags
bindToDevice tcpip.NICID
- broadcast bool
noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
@@ -157,6 +156,9 @@ type endpoint struct {
// linger is used for SO_LINGER socket option.
linger tcpip.LingerOption
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// +stateify savable
@@ -427,7 +429,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
to := opts.To
e.mu.RLock()
- defer e.mu.RUnlock()
+ lockReleased := false
+ defer func() {
+ if lockReleased {
+ return
+ }
+ e.mu.RUnlock()
+ }()
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
@@ -473,7 +481,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
if e.state != StateConnected {
err = tcpip.ErrInvalidEndpointState
}
- return
+ return ch, err
}
} else {
// Reject destination address if it goes through a different
@@ -508,7 +516,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
resolve = route.Resolve
}
- if !e.broadcast && route.IsOutboundBroadcast() {
+ if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
return 0, nil, tcpip.ErrBroadcastDisabled
}
@@ -539,7 +547,24 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
+ localPort := e.ID.LocalPort
+ sendTOS := e.sendTOS
+ owner := e.owner
+ noChecksum := e.noChecksum
+ lockReleased = true
+ e.mu.RUnlock()
+
+ // Do not hold lock when sending as loopback is synchronous and if the UDP
+ // datagram ends up generating an ICMP response then it can result in a
+ // deadlock where the ICMP response handling ends up acquiring this endpoint's
+ // mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a
+ // deadlock if another caller is trying to acquire e.mu in exclusive mode w/
+ // e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the
+ // lock can be eventually acquired.
+ //
+ // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
+ // locking is prohibited.
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -553,11 +578,6 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
- case tcpip.BroadcastOption:
- e.mu.Lock()
- e.broadcast = v
- e.mu.Unlock()
-
case tcpip.MulticastLoopOption:
e.mu.Lock()
e.multicastLoop = v
@@ -614,7 +634,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.v6only = v
}
-
return nil
}
@@ -830,12 +849,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.BroadcastOption:
- e.mu.RLock()
- v := e.broadcast
- e.mu.RUnlock()
- return v, nil
-
case tcpip.KeepaliveEnabledOption:
return false, nil
@@ -1522,6 +1535,12 @@ func isBroadcastOrMulticast(a tcpip.Address) bool {
return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index c09c7aa86..492e277a8 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -54,6 +55,7 @@ const (
stackPort = 1234
testAddr = "\x0a\x00\x00\x02"
testPort = 4096
+ invalidPort = 8192
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
@@ -295,7 +297,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
return newDualTestContextWithOptions(t, mtu, stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+ HandleLocal: true,
})
}
@@ -364,9 +367,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.t.Fatalf("SetSockOptBool failed: %s", err)
}
} else if flow.isBroadcast() {
- if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
- c.t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetBroadcast(true)
}
}
@@ -974,7 +975,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
// provided.
func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
- return testWriteInternal(c, flow, true, checkers...)
+ return testWriteAndVerifyInternal(c, flow, true, checkers...)
}
// testWriteWithoutDestination sends a packet of the given test flow from the
@@ -983,10 +984,10 @@ func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker
// checker functions provided.
func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
- return testWriteInternal(c, flow, false, checkers...)
+ return testWriteAndVerifyInternal(c, flow, false, checkers...)
}
-func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View {
c.t.Helper()
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
@@ -1008,6 +1009,12 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
c.checkEndpointWriteStats(1, epstats, err)
+ return payload
+}
+
+func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ payload := testWriteNoVerify(c, flow, setDest)
// Received the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
var udp header.UDP
@@ -1152,6 +1159,39 @@ func TestV4WriteOnConnected(t *testing.T) {
testWriteWithoutDestination(c, unicastV4)
}
+func TestWriteOnConnectedInvalidPort(t *testing.T) {
+ protocols := map[string]tcpip.NetworkProtocolNumber{
+ "ipv4": ipv4.ProtocolNumber,
+ "ipv6": ipv6.ProtocolNumber,
+ }
+ for name, pn := range protocols {
+ t.Run(name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(pn)
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
+ }
+ payload := buffer.View(newPayload())
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ if err != nil {
+ c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
+ }
+ if got, want := n, int64(len(payload)); got != want {
+ c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
+ }
+
+ if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused {
+ c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
+ }
+ })
+ }
+}
+
// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
// that is bound to a V4 multicast address.
func TestWriteOnBoundToV4Multicast(t *testing.T) {
@@ -1451,6 +1491,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -2393,17 +2437,13 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
}
- if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
- t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err)
- }
+ ep.SocketOptions().SetBroadcast(true)
if n, _, err := ep.Write(data, opts); err != nil {
t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err)
}
- if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil {
- t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err)
- }
+ ep.SocketOptions().SetBroadcast(false)
if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)