summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/BUILD8
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go73
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go2
-rw-r--r--pkg/tcpip/transport/packet/BUILD38
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go362
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go72
-rw-r--r--pkg/tcpip/transport/raw/BUILD21
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go47
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go14
-rw-r--r--pkg/tcpip/transport/raw/protocol.go12
-rw-r--r--pkg/tcpip/transport/tcp/BUILD13
-rw-r--r--pkg/tcpip/transport/tcp/accept.go69
-rw-r--r--pkg/tcpip/transport/tcp/connect.go556
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go308
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go33
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go5
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go86
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go190
-rw-r--r--pkg/tcpip/transport/tcp/rcv_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/segment.go15
-rw-r--r--pkg/tcpip/transport/tcp/snd.go48
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go1961
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD2
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go77
-rw-r--r--pkg/tcpip/transport/udp/BUILD9
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go146
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go2
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go9
-rw-r--r--pkg/tcpip/transport/udp/protocol.go33
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go33
31 files changed, 3868 insertions, 415 deletions
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 9254c3dea..d8c5b5058 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -38,11 +38,3 @@ go_library(
"//pkg/waiter",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "icmp_packet_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 3187b336b..9c40931b5 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -31,9 +31,6 @@ type icmpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- // views is used as buffer for data when its length is large
- // enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
}
type endpointState int
@@ -58,6 +55,7 @@ type endpoint struct {
// immutable.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
+ uniqueID uint64
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -90,9 +88,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
state: stateInitial,
+ uniqueID: s.UniqueID(),
}, nil
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -274,13 +278,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
} else {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
- nicid := to.NIC
+ nicID := to.NIC
if e.BindNICID != 0 {
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
}
toCopy := *to
@@ -291,7 +295,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
// Find the enpoint.
- r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
if err != nil {
return 0, nil, err
}
@@ -425,7 +429,11 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: data.ToVectorisedView(),
+ TransportHeader: buffer.View(icmpv4),
+ })
}
func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
@@ -445,13 +453,17 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
return tcpip.ErrInvalidEndpointState
}
- icmpv6.SetChecksum(0)
- icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
+ dataVV := data.ToVectorisedView()
+ icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV))
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: dataVV,
+ TransportHeader: buffer.View(icmpv6),
+ })
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
@@ -479,7 +491,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicid := addr.NIC
+ nicID := addr.NIC
localPort := uint16(0)
switch e.state {
case stateBound, stateConnected:
@@ -488,11 +500,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
break
}
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
@@ -503,7 +515,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
@@ -520,14 +532,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// v6only is set to false and this is an ipv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- id, err = e.registerWithStack(nicid, netProtos, id)
+ id, err = e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
e.ID = id
e.route = r.Clone()
- e.RegisterNICID = nicid
+ e.RegisterNICID = nicID
e.state = stateConnected
@@ -578,18 +590,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
-func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
@@ -714,18 +726,18 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h := header.ICMPv4(vv.First())
+ h := header.ICMPv4(pkt.Data.First())
if h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h := header.ICMPv6(vv.First())
+ h := header.ICMPv6(pkt.Data.First())
if h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -753,19 +765,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- pkt := &icmpPacket{
+ packet := &icmpPacket{
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
},
}
- pkt.data = vv.Clone(pkt.views[:])
+ packet.data = pkt.Data
- e.rcvList.PushBack(pkt)
- e.rcvBufSize += pkt.data.Size()
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
- pkt.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
@@ -776,7 +788,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
}
// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
@@ -798,3 +810,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo {
func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index bfb16f7c3..9ce500e80 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
return true
}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
new file mode 100644
index 000000000..44b58ff6b
--- /dev/null
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -0,0 +1,38 @@
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "packet_list",
+ out = "packet_list.go",
+ package = "packet",
+ prefix = "packet",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*packet",
+ "Linker": "*packet",
+ },
+)
+
+go_library(
+ name = "packet",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "packet_list.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet",
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sleep",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
new file mode 100644
index 000000000..0010b5e5f
--- /dev/null
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -0,0 +1,362 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packet provides the implementation of packet sockets (see
+// packet(7)). Packet sockets allow applications to:
+//
+// * manually write and inspect link, network, and transport headers
+// * receive all traffic of a given network protocol, or all protocols
+//
+// Packet sockets are similar to raw sockets, but provide even more power to
+// users, letting them effectively talk directly to the network device.
+//
+// Packet sockets skip the input and output iptables chains.
+package packet
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type packet struct {
+ packetEntry
+ // data holds the actual packet data, including any headers and
+ // payload.
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // timestampNS is the unix time at which the packet was received.
+ timestampNS int64
+ // senderAddr is the network address of the sender.
+ senderAddr tcpip.FullAddress
+}
+
+// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
+// to have goroutines make concurrent calls into the endpoint.
+//
+// Lock order:
+// endpoint.mu
+// endpoint.rcvMu
+//
+// +stateify savable
+type endpoint struct {
+ stack.TransportEndpointInfo
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue
+ cooked bool
+
+ // The following fields are used to manage the receive queue and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+}
+
+// NewEndpoint returns a new packet endpoint.
+func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := &endpoint{
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ },
+ cooked: cooked,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+
+ if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
+ return nil, err
+ }
+ return ep, nil
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (ep *endpoint) Close() {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return
+ }
+
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+
+ // Clear the receive list.
+ ep.rcvClosed = true
+ ep.rcvBufSize = 0
+ for !ep.rcvList.Empty() {
+ ep.rcvList.Remove(ep.rcvList.Front())
+ }
+
+ ep.closed = true
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (ep *endpoint) ModerateRecvBuf(copied int) {}
+
+// IPTables implements tcpip.Endpoint.IPTables.
+func (ep *endpoint) IPTables() (iptables.IPTables, error) {
+ return ep.stack.IPTables(), nil
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ ep.rcvMu.Lock()
+
+ // If there's no data to read, return that read would block or that the
+ // endpoint is closed.
+ if ep.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if ep.rcvClosed {
+ ep.stats.ReadErrors.ReadClosed.Increment()
+ err = tcpip.ErrClosedForReceive
+ }
+ ep.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ packet := ep.rcvList.Front()
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = packet.senderAddr
+ }
+
+ return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+}
+
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // TODO(b/129292371): Implement.
+ return 0, nil, tcpip.ErrInvalidOptionValue
+}
+
+// Peek implements tcpip.Endpoint.Peek.
+func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
+// disconnected, and this function always returns tpcip.ErrNotSupported.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
+// connected, and this function always returnes tcpip.ErrNotSupported.
+func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
+// with Shutdown, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
+// Listen, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
+// Accept, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ // TODO(gvisor.dev/issue/173): Add Bind support.
+
+ // "By default, all packets of the specified protocol type are passed
+ // to a packet socket. To get packets only from a specific interface
+ // use bind(2) specifying an address in a struct sockaddr_ll to bind
+ // the packet socket to an interface. Fields used for binding are
+ // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
+ // - packet(7).
+
+ return tcpip.ErrNotSupported
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ // Even a connected socket doesn't return a remote address.
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine whether the endpoint is readable.
+ if (mask & waiter.EventIn) != 0 {
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() || ep.rcvClosed {
+ result |= waiter.EventIn
+ }
+ ep.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be
+// used with SetSockOpt, and this function always returns
+// tcpip.ErrNotSupported.
+func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// HandlePacket implements stack.PacketEndpoint.HandlePacket.
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+ ep.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if ep.rcvClosed {
+ ep.rcvMu.Unlock()
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if ep.rcvBufSize >= ep.rcvBufSizeMax {
+ ep.rcvMu.Unlock()
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return
+ }
+
+ wasEmpty := ep.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ var packet packet
+ // TODO(b/129292371): Return network protocol.
+ if len(pkt.LinkHeader) > 0 {
+ // Get info directly from the ethernet header.
+ hdr := header.Ethernet(pkt.LinkHeader)
+ packet.senderAddr = tcpip.FullAddress{
+ NIC: nicID,
+ Addr: tcpip.Address(hdr.SourceAddress()),
+ }
+ } else {
+ // Guess the would-be ethernet header.
+ packet.senderAddr = tcpip.FullAddress{
+ NIC: nicID,
+ Addr: tcpip.Address(localAddr),
+ }
+ }
+
+ if ep.cooked {
+ // Cooked packets can simply be queued.
+ packet.data = pkt.Data
+ } else {
+ // Raw packets need their ethernet headers prepended before
+ // queueing.
+ var linkHeader buffer.View
+ if len(pkt.LinkHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ }
+ combinedVV := linkHeader.ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ }
+ packet.timestampNS = ep.stack.NowNanoseconds()
+
+ ep.rcvList.PushBack(&packet)
+ ep.rcvBufSize += packet.data.Size()
+
+ ep.rcvMu.Unlock()
+ ep.stats.PacketsReceived.Increment()
+ // Notify waiters that there's data to be read.
+ if wasEmpty {
+ ep.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// State implements socket.Socket.State.
+func (ep *endpoint) State() uint32 {
+ return 0
+}
+
+// Info returns a copy of the endpoint info.
+func (ep *endpoint) Info() tcpip.EndpointInfo {
+ ep.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := ep.TransportEndpointInfo
+ ep.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (ep *endpoint) Stats() tcpip.EndpointStats {
+ return &ep.stats
+}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
new file mode 100644
index 000000000..9b88f17e4
--- /dev/null
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -0,0 +1,72 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package packet
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves packet.data field.
+func (p *packet) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads packet.data field.
+func (p *packet) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after saveRcvBufSizeMax(), which would have
+ // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+ max := ep.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ ep.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ ep.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+ ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+ // StackFromEnv is a stack used specifically for save/restore.
+ ep.stack = stack.StackFromEnv
+
+ // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
+ if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index fba598d51..00991ac8e 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -4,14 +4,14 @@ load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
go_template_instance(
- name = "packet_list",
- out = "packet_list.go",
+ name = "raw_packet_list",
+ out = "raw_packet_list.go",
package = "raw",
- prefix = "packet",
+ prefix = "rawPacket",
template = "//pkg/ilist:generic_list",
types = {
- "Element": "*packet",
- "Linker": "*packet",
+ "Element": "*rawPacket",
+ "Linker": "*rawPacket",
},
)
@@ -20,8 +20,8 @@ go_library(
srcs = [
"endpoint.go",
"endpoint_state.go",
- "packet_list.go",
"protocol.go",
+ "raw_packet_list.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
@@ -34,14 +34,7 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "packet_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b4c660859..5aafe2615 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -17,8 +17,7 @@
//
// * manually write and inspect transport layer headers and payloads
// * receive all traffic of a given transport protocol (e.g. ICMP or UDP)
-// * optionally write and inspect network layer and link layer headers for
-// packets
+// * optionally write and inspect network layer headers of packets
//
// Raw sockets don't have any notion of ports, and incoming packets are
// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will
@@ -38,15 +37,11 @@ import (
)
// +stateify savable
-type packet struct {
- packetEntry
+type rawPacket struct {
+ rawPacketEntry
// data holds the actual packet data, including any headers and
// payload.
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // views is pre-allocated space to back data. As long as the packet is
- // made up of fewer than 8 buffer.Views, no extra allocation is
- // necessary to store packet data.
- views [8]buffer.View `state:"nosave"`
// timestampNS is the unix time at which the packet was received.
timestampNS int64
// senderAddr is the network address of the sender.
@@ -72,7 +67,7 @@ type endpoint struct {
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
+ rcvList rawPacketList
rcvBufSizeMax int `state:".(int)"`
rcvBufSize int
rcvClosed bool
@@ -90,7 +85,6 @@ type endpoint struct {
}
// NewEndpoint returns a raw endpoint for the given protocols.
-// TODO(b/129292371): IP_HDRINCL and AF_PACKET.
func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
}
@@ -187,17 +181,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
return buffer.View{}, tcpip.ControlMessages{}, err
}
- packet := e.rcvList.Front()
- e.rcvList.Remove(packet)
- e.rcvBufSize -= packet.data.Size()
+ pkt := e.rcvList.Front()
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
e.rcvMu.Unlock()
if addr != nil {
- *addr = packet.senderAddr
+ *addr = pkt.senderAddr
}
- return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+ return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
}
// Write implements tcpip.Endpoint.Write.
@@ -344,13 +338,18 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
switch e.NetProto {
case header.IPv4ProtocolNumber:
if !e.associated {
- if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil {
+ if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ }); err != nil {
return 0, nil, err
}
break
}
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ }); err != nil {
return 0, nil, err
}
@@ -561,7 +560,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -602,16 +601,17 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- packet := &packet{
+ packet := &rawPacket{
senderAddr: tcpip.FullAddress{
NIC: route.NICID(),
Addr: route.RemoteAddress,
},
}
- combinedVV := netHeader.ToVectorisedView()
- combinedVV.Append(vv)
- packet.data = combinedVV.Clone(packet.views[:])
+ networkHeader := append(buffer.View(nil), pkt.NetworkHeader...)
+ combinedVV := networkHeader.ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
packet.timestampNS = e.stack.NowNanoseconds()
e.rcvList.PushBack(packet)
@@ -643,3 +643,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo {
func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index a6c7cc43a..33bfb56cd 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -20,15 +20,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// saveData saves packet.data field.
-func (p *packet) saveData() buffer.VectorisedView {
+// saveData saves rawPacket.data field.
+func (p *rawPacket) saveData() buffer.VectorisedView {
// We cannot save p.data directly as p.data.views may alias to p.views,
// which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
-// loadData loads packet.data field.
-func (p *packet) loadData(data buffer.VectorisedView) {
+// loadData loads rawPacket.data field.
+func (p *rawPacket) loadData(data buffer.VectorisedView) {
// NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
// here because data.views is not guaranteed to be loaded by now. Plus,
// data.views will be allocated anyway so there really is little point
@@ -86,7 +86,9 @@ func (ep *endpoint) Resume(s *stack.Stack) {
}
}
- if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
- panic(err)
+ if ep.associated {
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
+ panic(err)
+ }
}
}
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index a2512d666..f30aa2a4a 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -17,13 +17,19 @@ package raw
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/packet"
"gvisor.dev/gvisor/pkg/waiter"
)
-// EndpointFactory implements stack.UnassociatedEndpointFactory.
+// EndpointFactory implements stack.RawFactory.
type EndpointFactory struct{}
-// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint.
+func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
+
+// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint.
+func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return packet.NewEndpoint(stack, cooked, netProto, waiterQueue)
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index aed70e06f..3b353d56c 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -28,6 +28,7 @@ go_library(
"forwarder.go",
"protocol.go",
"rcv.go",
+ "rcv_state.go",
"reno.go",
"sack.go",
"sack_scoreboard.go",
@@ -44,6 +45,7 @@ go_library(
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
@@ -51,6 +53,7 @@ go_library(
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
@@ -60,17 +63,9 @@ go_library(
],
)
-filegroup(
- name = "autogen",
- srcs = [
- "tcp_segment_list.go",
- ],
- visibility = ["//:sandbox"],
-)
-
go_test(
name = "tcp_test",
- size = "small",
+ size = "medium",
srcs = [
"dual_stack_test.go",
"sack_scoreboard_test.go",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 844959fa0..5422ae80c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -241,8 +242,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
+ // Now inherit any socket options that should be inherited from the
+ // listening endpoint.
+ // In case of Forwarder listenEP will be nil and hence this check.
+ if l.listenEP != nil {
+ l.listenEP.propagateInheritableOptions(n)
+ }
+
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
n.Close()
return nil, err
}
@@ -268,8 +276,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
- ep, err := l.createConnectingEndpoint(s, cookie, irs, opts)
+ isn := generateSecureISN(s.id, l.stack.Seed())
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts)
if err != nil {
return nil, err
}
@@ -288,7 +296,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Perform the 3-way handshake.
h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
- h.resetToSynRcvd(cookie, irs, opts)
+ h.resetToSynRcvd(isn, irs, opts)
if err := h.execute(); err != nil {
ep.Close()
if l.listenEP != nil {
@@ -297,7 +305,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, err
}
ep.mu.Lock()
- ep.state = StateEstablished
+ ep.isConnectNotified = true
ep.mu.Unlock()
// Update the receive window scaling. We can't do it before the
@@ -349,6 +357,14 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
}
}
+// propagateInheritableOptions propagates any options set on the listening
+// endpoint to the newly created endpoint.
+func (e *endpoint) propagateInheritableOptions(n *endpoint) {
+ e.mu.Lock()
+ n.userTimeout = e.userTimeout
+ e.mu.Unlock()
+}
+
// handleSynSegment is called in its own goroutine once the listening endpoint
// receives a SYN segment. It is responsible for completing the handshake and
// queueing the new endpoint for acceptance.
@@ -359,6 +375,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
defer decSynRcvdCount()
defer e.decSynRcvdCount()
defer s.decRef()
+
n, err := ctx.createEndpointAndPerformHandshake(s, opts)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -366,6 +383,11 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
return
}
ctx.removePendingEndpoint(n)
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
e.deliverAccepted(n)
}
@@ -399,8 +421,19 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
- switch s.flags {
- case header.TCPFlagSyn:
+ if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) {
+ // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
+ // must be sent in response to a SYN-ACK while in the listen
+ // state to prevent completing a handshake from an old SYN.
+ e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil)
+ return
+ }
+
+ // TODO(b/143300739): Use the userMSS of the listening socket
+ // for accepted sockets.
+
+ switch {
+ case s.flags == header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
if incSynRcvdCount() {
// Only handle the syn if the following conditions hold
@@ -433,19 +466,18 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
- mss := mssForRoute(&s.route)
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
TSVal: tcpTimeStamp(timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: uint16(mss),
+ MSS: mssForRoute(&s.route),
}
e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
}
- case header.TCPFlagAck:
+ case (s.flags & header.TCPFlagAck) != 0:
if e.acceptQueueIsFull() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
@@ -459,6 +491,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
}
if !synCookiesInUse() {
+ // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
+ // Any acknowledgment is bad if it arrives on a connection still in
+ // the LISTEN state. An acceptable reset segment should be formed
+ // for any arriving ACK-bearing segment. The RST should be
+ // formatted as follows:
+ //
+ // <SEQ=SEG.ACK><CTL=RST>
+ //
// Send a reset as this is an ACK for which there is no
// half open connections and we are not using cookies
// yet.
@@ -519,7 +559,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.tsOffset = 0
// Switch state to connected.
+ // We do not use transitionToStateEstablishedLocked here as there is
+ // no handshake state available when doing a SYN cookie based accept.
+ n.stack.Stats().TCP.CurrentEstablished.Increment()
n.state = StateEstablished
+ n.isConnectNotified = true
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
@@ -530,6 +574,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// number of goroutines as we do check before
// entering here that there was at least some
// space available in the backlog.
+
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
}
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5ea036bea..cdd69f360 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"sync"
"time"
@@ -22,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -78,9 +80,6 @@ type handshake struct {
// mss is the maximum segment size received from the peer.
mss uint16
- // amss is the maximum segment size advertised by us to the peer.
- amss uint16
-
// sndWndScale is the send window scale, as defined in RFC 1323. A
// negative value means no scaling is supported by the peer.
sndWndScale int
@@ -142,7 +141,32 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24)
+ h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed())
+}
+
+// generateSecureISN generates a secure Initial Sequence number based on the
+// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
+func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
+ isnHasher := jenkins.Sum32(seed)
+ isnHasher.Write([]byte(id.LocalAddress))
+ isnHasher.Write([]byte(id.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
+ isnHasher.Write(portBuf)
+ binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
+ isnHasher.Write(portBuf)
+ // The time period here is 64ns. This is similar to what linux uses
+ // generate a sequence number that overlaps less than one
+ // time per MSL (2 minutes).
+ //
+ // A 64ns clock ticks 10^9/64 = 15625000) times in a second.
+ // To wrap the whole 32 bit space would require
+ // 2^32/1562500 ~ 274 seconds.
+ //
+ // Which sort of guarantees that we won't reuse the ISN for a new
+ // connection for the same tuple for at least 274s.
+ isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6)
+ return seqnum.Value(isn)
}
// effectiveRcvWndScale returns the effective receive window scale to be used.
@@ -194,6 +218,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// acceptable if the ack field acknowledges the SYN.
if s.flagIsSet(header.TCPFlagRst) {
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
+ // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK
+ // was acceptable then signal the user "error: connection reset", drop
+ // the segment, enter CLOSED state, delete TCB, and return."
+ h.ep.mu.Lock()
+ h.ep.workerCleanup = true
+ h.ep.mu.Unlock()
+ // Although the RFC above calls out ECONNRESET, Linux actually returns
+ // ECONNREFUSED here so we do as well.
return tcpip.ErrConnectionRefused
}
return nil
@@ -228,6 +260,11 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// and the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
h.state = handshakeCompleted
+
+ h.ep.mu.Lock()
+ h.ep.transitionToStateEstablishedLocked(h)
+ h.ep.mu.Unlock()
+
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
}
@@ -275,6 +312,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return nil
}
+ // RFC 793, Section 3.9, page 69, states that in the SYN-RCVD state, a
+ // sequence number outside of the window causes an ACK with the proper seq
+ // number and "After sending the acknowledgment, drop the unacceptable
+ // segment and return."
+ if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd)
+ return nil
+ }
+
if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
// We received two SYN segments with different sequence
// numbers, so we reset this and restart the whole
@@ -319,6 +365,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
h.state = handshakeCompleted
+ h.ep.mu.Lock()
+ h.ep.transitionToStateEstablishedLocked(h)
+ h.ep.mu.Unlock()
+
return nil
}
@@ -445,7 +495,7 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
- h.ep.amss = mssForRoute(&h.ep.route)
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
@@ -607,19 +657,14 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu
return nil
}
-// sendTCP sends a TCP segment with the provided options via the provided
-// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) {
optLen := len(opts)
- // Allocate a buffer for the TCP header.
- hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
-
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
- }
-
+ hdr := &pkt.Header
+ packetSize := pkt.DataSize
+ off := pkt.DataOffset
// Initialize the header.
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
+ pkt.TransportHeader = buffer.View(tcp)
tcp.Encode(&header.TCPFields{
SrcPort: id.LocalPort,
DstPort: id.RemotePort,
@@ -631,7 +676,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
})
copy(tcp[header.TCPMinimumSize:], opts)
- length := uint16(hdr.UsedLength() + data.Size())
+ length := uint16(hdr.UsedLength() + packetSize)
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
// Only calculate the checksum if offloading isn't supported.
if gso != nil && gso.NeedsCsum {
@@ -641,14 +686,79 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
} else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVV(data, xsum)
+ xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
+}
+
+func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ mss := int(gso.MSS)
+ n := (data.Size() + mss - 1) / mss
+
+ // Allocate one big slice for all the headers.
+ hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen
+ buf := make([]byte, n*hdrSize)
+ pkts := make([]tcpip.PacketBuffer, n)
+ for i := range pkts {
+ pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize])
+ }
+
+ size := data.Size()
+ off := 0
+ for i := 0; i < n; i++ {
+ packetSize := mss
+ if packetSize > size {
+ packetSize = size
+ }
+ size -= packetSize
+ pkts[i].DataOffset = off
+ pkts[i].DataSize = packetSize
+ pkts[i].Data = data
+ buildTCPHdr(r, id, &pkts[i], flags, seq, ack, rcvWnd, opts, gso)
+ off += packetSize
+ seq = seq.Add(seqnum.Size(packetSize))
+ }
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos})
+ if err != nil {
+ r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent))
+ }
+ r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent))
+ return err
+}
+
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
+ return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso)
+ }
+
+ pkt := tcpip.PacketBuffer{
+ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
+ DataOffset: 0,
+ DataSize: data.Size(),
+ Data: data,
+ }
+ buildTCPHdr(r, id, &pkt, flags, seq, ack, rcvWnd, opts, gso)
+
if ttl == 0 {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, pkt); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
return err
}
@@ -754,10 +864,26 @@ func (e *endpoint) handleClose() *tcpip.Error {
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
+ if e.state == StateEstablished || e.state == StateCloseWait {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ }
e.state = StateError
e.HardError = err
- if err != tcpip.ErrConnectionReset {
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
+ // The exact sequence number to be used for the RST is the same as the
+ // one used by Linux. We need to handle the case of window being shrunk
+ // which can cause sndNxt to be outside the acceptable window on the
+ // receiver.
+ //
+ // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more
+ // information.
+ sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd)
+ resetSeqNum := sndWndEnd
+ if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) {
+ resetSeqNum = e.snd.sndNxt
+ }
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0)
}
}
@@ -771,6 +897,99 @@ func (e *endpoint) completeWorkerLocked() {
}
}
+// transitionToStateEstablisedLocked transitions a given endpoint
+// to an established state using the handshake parameters provided.
+// It also initializes sender/receiver if required.
+func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
+ if e.snd == nil {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+ }
+ if e.rcv == nil {
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd)
+ e.rcvListMu.Unlock()
+ }
+ h.ep.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.state = StateEstablished
+}
+
+// transitionToStateCloseLocked ensures that the endpoint is
+// cleaned up from the transport demuxer, "before" moving to
+// StateClose. This will ensure that no packet will be
+// delivered to this endpoint from the demuxer when the endpoint
+// is transitioned to StateClose.
+func (e *endpoint) transitionToStateCloseLocked() {
+ if e.state == StateClose {
+ return
+ }
+ e.cleanupLocked()
+ e.state = StateClose
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
+}
+
+// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
+// segment to any other endpoint other than the current one. This is called
+// only when the endpoint is in StateClose and we want to deliver the segment
+// to any other listening endpoint. We reply with RST if we cannot find one.
+func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ if ep == nil {
+ replyWithReset(s)
+ s.decRef()
+ return
+ }
+ ep.(*endpoint).enqueueSegment(s)
+}
+
+func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
+ if e.rcv.acceptable(s.sequenceNumber, 0) {
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+ s.decRef()
+ e.mu.Lock()
+ switch e.state {
+ // In case of a RST in CLOSE-WAIT linux moves
+ // the socket to closed state with an error set
+ // to indicate EPIPE.
+ //
+ // Technically this seems to be at odds w/ RFC.
+ // As per https://tools.ietf.org/html/rfc793#section-2.7
+ // page 69 the behavior for a segment arriving
+ // w/ RST bit set in CLOSE-WAIT is inlined below.
+ //
+ // ESTABLISHED
+ // FIN-WAIT-1
+ // FIN-WAIT-2
+ // CLOSE-WAIT
+
+ // If the RST bit is set then, any outstanding RECEIVEs and
+ // SEND should receive "reset" responses. All segment queues
+ // should be flushed. Users should also receive an unsolicited
+ // general "connection reset" signal. Enter the CLOSED state,
+ // delete the TCB, and return.
+ case StateCloseWait:
+ e.transitionToStateCloseLocked()
+ e.HardError = tcpip.ErrAborted
+ e.mu.Unlock()
+ return false, nil
+ default:
+ e.mu.Unlock()
+ return false, tcpip.ErrConnectionReset
+ }
+ }
+ return true, nil
+}
+
// handleSegments pulls segments from the queue and processes them. It returns
// no error if the protocol loop should continue, an error otherwise.
func (e *endpoint) handleSegments() *tcpip.Error {
@@ -788,14 +1007,34 @@ func (e *endpoint) handleSegments() *tcpip.Error {
}
if s.flagIsSet(header.TCPFlagRst) {
- if e.rcv.acceptable(s.sequenceNumber, 0) {
- // RFC 793, page 37 states that "in all states
- // except SYN-SENT, all reset (RST) segments are
- // validated by checking their SEQ-fields." So
- // we only process it if it's acceptable.
- s.decRef()
- return tcpip.ErrConnectionReset
+ if ok, err := e.handleReset(s); !ok {
+ return err
}
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+
+ // This RST will confirm that the remote peer has indeed closed the
+ // previous connection. Upon receipt of a valid RST, the local TCP
+ // endpoint MUST terminate its connection. The local TCP endpoint
+ // should then rely on SYN retransmission from the remote end to
+ // re-establish the connection.
+
+ e.snd.sendAck()
} else if s.flagIsSet(header.TCPFlagAck) {
// Patch the window size in the segment according to the
// send window scale.
@@ -804,7 +1043,33 @@ func (e *endpoint) handleSegments() *tcpip.Error {
// RFC 793, page 41 states that "once in the ESTABLISHED
// state all segments must carry current acknowledgment
// information."
- e.rcv.handleRcvdSegment(s)
+ drop, err := e.rcv.handleRcvdSegment(s)
+ if err != nil {
+ s.decRef()
+ return err
+ }
+ if drop {
+ s.decRef()
+ continue
+ }
+
+ // Now check if the received segment has caused us to transition
+ // to a CLOSED state, if yes then terminate processing and do
+ // not invoke the sender.
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == StateClose {
+ // When we get into StateClose while processing from the queue,
+ // return immediately and let the protocolMainloop handle it.
+ //
+ // We can reach StateClose only while processing a previous segment
+ // or a notification from the protocolMainLoop (caller goroutine).
+ // This means that with this return, the segment dequeue below can
+ // never occur on a closed endpoint.
+ s.decRef()
+ return nil
+ }
e.snd.handleRcvdSegment(s)
}
s.decRef()
@@ -830,14 +1095,27 @@ func (e *endpoint) handleSegments() *tcpip.Error {
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+ e.mu.RLock()
+ userTimeout := e.userTimeout
+ e.mu.RUnlock()
+
e.keepalive.Lock()
if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
e.keepalive.Unlock()
return nil
}
+ // If a userTimeout is set then abort the connection if it is
+ // exceeded.
+ if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
+ e.keepalive.Unlock()
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
+ return tcpip.ErrTimeout
+ }
+
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
return tcpip.ErrTimeout
}
@@ -854,7 +1132,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
// whether it is enabled for this endpoint.
func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
e.keepalive.Lock()
- defer e.keepalive.Unlock()
if receivedData {
e.keepalive.unacked = 0
}
@@ -862,6 +1139,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
// data to send.
if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
e.keepalive.timer.disable()
+ e.keepalive.Unlock()
return
}
if e.keepalive.unacked > 0 {
@@ -869,6 +1147,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
} else {
e.keepalive.timer.enable(e.keepalive.idle)
}
+ e.keepalive.Unlock()
}
// disableKeepaliveTimer stops the keepalive timer.
@@ -903,7 +1182,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.mu.Unlock()
-
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -932,21 +1210,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
return err
}
-
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
-
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
- e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
- // boot strap the auto tuning algorithm. Starting at zero will
- // result in a large step function on the first proper causing
- // the window to just go to a really large value after the first
- // RTT itself.
- e.rcvAutoParams.prevCopied = initialRcvWnd
- e.rcvListMu.Unlock()
}
e.keepalive.timer.init(&e.keepalive.waker)
@@ -954,7 +1217,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
- e.state = StateEstablished
drained := e.drainDone != nil
e.mu.Unlock()
if drained {
@@ -985,13 +1247,20 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
{
w: &closeWaker,
f: func() *tcpip.Error {
- return tcpip.ErrConnectionAborted
+ // This means the socket is being closed due
+ // to the TCP_FIN_WAIT2 timeout was hit. Just
+ // mark the socket as closed.
+ e.mu.Lock()
+ e.transitionToStateCloseLocked()
+ e.mu.Unlock()
+ return nil
},
},
{
w: &e.snd.resendWaker,
f: func() *tcpip.Error {
if !e.snd.retransmitTimerExpired() {
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
return tcpip.ErrTimeout
}
return nil
@@ -1028,17 +1297,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
e.mu.Unlock()
}
+
if n&notifyClose != 0 && closeTimer == nil {
- // Reset the connection 3 seconds after
- // the endpoint has been closed.
- //
- // The timer could fire in background
- // when the endpoint is drained. That's
- // OK as the loop here will not honor
- // the firing until the undrain arrives.
- closeTimer = time.AfterFunc(3*time.Second, func() {
- closeWaker.Assert()
- })
+ e.mu.Lock()
+ if e.state == StateFinWait2 && e.closed {
+ // The socket has been closed and we are in FIN_WAIT2
+ // so start the FIN_WAIT2 timer.
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
+ closeWaker.Assert()
+ })
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+ e.mu.Unlock()
}
if n&notifyKeepaliveChanged != 0 {
@@ -1054,12 +1324,20 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
return err
}
}
- if e.state != StateError {
+ if e.state != StateClose && e.state != StateError {
+ // Only block the worker if the endpoint
+ // is not in closed state or error state.
close(e.drainDone)
<-e.undrain
}
}
+ if n&notifyTickleWorker != 0 {
+ // Just a tickle notification. No need to do
+ // anything.
+ return nil
+ }
+
return nil
},
},
@@ -1086,15 +1364,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.rcvListMu.Unlock()
- e.mu.RLock()
+ e.mu.Lock()
if e.workerCleanup {
e.notifyProtocolGoroutine(notifyClose)
}
- e.mu.RUnlock()
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
- for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
+
+ for e.state != StateTimeWait && e.state != StateClose && e.state != StateError {
+ e.mu.Unlock()
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
@@ -1110,15 +1389,168 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
return nil
}
+ e.mu.Lock()
+ }
+
+ state := e.state
+ e.mu.Unlock()
+ var reuseTW func()
+ if state == StateTimeWait {
+ // Disable close timer as we now entering real TIME_WAIT.
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+ // Mark the current sleeper done so as to free all associated
+ // wakers.
+ s.Done()
+ // Wake up any waiters before we enter TIME_WAIT.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
e.mu.Lock()
if e.state != StateError {
- e.state = StateClose
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ e.transitionToStateCloseLocked()
}
+
// Lock released below.
epilogue()
+ // epilogue removes the endpoint from the transport-demuxer and
+ // unlocks e.mu. Now that no new segments can get enqueued to this
+ // endpoint, try to re-match the segment to a different endpoint
+ // as the current endpoint is closed.
+ for {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ break
+ }
+
+ e.tryDeliverSegmentFromClosedEndpoint(s)
+ }
+
+ // A new SYN was received during TIME_WAIT and we need to abort
+ // the timewait and redirect the segment to the listener queue
+ if reuseTW != nil {
+ reuseTW()
+ }
+
return nil
}
+
+// handleTimeWaitSegments processes segments received during TIME_WAIT
+// state.
+func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+ extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
+ if newSyn {
+ info := e.EndpointInfo.TransportEndpointInfo
+ newID := info.ID
+ newID.RemoteAddress = ""
+ newID.RemotePort = 0
+ netProtos := []tcpip.NetworkProtocolNumber{info.NetProto}
+ // If the local address is an IPv4 address then also
+ // look for IPv6 dual stack endpoints that might be
+ // listening on the local address.
+ if newID.LocalAddress.To4() != "" {
+ netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
+ }
+ for _, netProto := range netProtos {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ tcpEP := listenEP.(*endpoint)
+ if EndpointState(tcpEP.State()) == StateListen {
+ reuseTW = func() {
+ tcpEP.enqueueSegment(s)
+ }
+ // We explicitly do not decRef
+ // the segment as it's still
+ // valid and being reflected to
+ // a listening endpoint.
+ return false, reuseTW
+ }
+ }
+ }
+ }
+ if extTW {
+ extendTimeWait = true
+ }
+ s.decRef()
+ }
+ if checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ return extendTimeWait, nil
+}
+
+// doTimeWait is responsible for handling the TCP behaviour once a socket
+// enters the TIME_WAIT state. Optionally it can return a closure that
+// should be executed after releasing the endpoint registrations. This is
+// done in cases where a new SYN is received during TIME_WAIT that carries
+// a sequence number larger than one see on the connection.
+func (e *endpoint) doTimeWait() (twReuse func()) {
+ // Trigger a 2 * MSL time wait state. During this period
+ // we will drop all incoming segments.
+ // NOTE: On Linux this is not configurable and is fixed at 60 seconds.
+ timeWaitDuration := DefaultTCPTimeWaitTimeout
+
+ // Get the stack wide configuration.
+ var tcpTW tcpip.TCPTimeWaitTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil {
+ timeWaitDuration = time.Duration(tcpTW)
+ }
+
+ const newSegment = 1
+ const notification = 2
+ const timeWaitDone = 3
+
+ s := sleep.Sleeper{}
+ s.AddWaker(&e.newSegmentWaker, newSegment)
+ s.AddWaker(&e.notificationWaker, notification)
+
+ var timeWaitWaker sleep.Waker
+ s.AddWaker(&timeWaitWaker, timeWaitDone)
+ timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
+ defer timeWaitTimer.Stop()
+
+ for {
+ e.workMu.Unlock()
+ v, _ := s.Fetch(true)
+ e.workMu.Lock()
+ switch v {
+ case newSegment:
+ extendTimeWait, reuseTW := e.handleTimeWaitSegments()
+ if reuseTW != nil {
+ return reuseTW
+ }
+ if extendTimeWait {
+ timeWaitTimer.Reset(timeWaitDuration)
+ }
+ case notification:
+ n := e.fetchNotifications()
+ if n&notifyClose != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ // Ignore extending TIME_WAIT during a
+ // save. For sockets in TIME_WAIT we just
+ // terminate the TIME_WAIT early.
+ e.handleTimeWaitSegments()
+ }
+ close(e.drainDone)
+ <-e.undrain
+ return nil
+ }
+ case timeWaitDone:
+ return nil
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a1b784b49..fe629aa40 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tmutex"
@@ -121,6 +122,11 @@ const (
notifyReset
notifyKeepaliveChanged
notifyMSSChanged
+ // notifyTickleWorker is used to tickle the protocol main loop during a
+ // restore after we update the endpoint state to the correct one. This
+ // ensures the loop terminates if the final state of the endpoint is
+ // say TIME_WAIT.
+ notifyTickleWorker
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -287,6 +293,7 @@ type endpoint struct {
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
+ uniqueID uint64
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
@@ -319,6 +326,11 @@ type endpoint struct {
state EndpointState `state:".(EndpointState)"`
+ // origEndpointState is only used during a restore phase to save the
+ // endpoint state at restore time as the socket is moved to it's correct
+ // state.
+ origEndpointState EndpointState `state:"nosave"`
+
isPortReserved bool `state:"manual"`
isRegistered bool
boundNICID tcpip.NICID `state:"manual"`
@@ -330,6 +342,11 @@ type endpoint struct {
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool
+ // Values used to reserve a port or register a transport endpoint
+ // (which ever happens first).
+ boundBindToDevice tcpip.NICID
+ boundPortFlags ports.Flags
+
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple
@@ -411,7 +428,7 @@ type endpoint struct {
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
- userMSS int
+ userMSS uint16
// The following fields are used to manage the send buffer. When
// segments are ready to be sent, they are added to sndQueue and the
@@ -458,6 +475,12 @@ type endpoint struct {
// without hearing a response, the connection is closed.
keepalive keepalive
+ // userTimeout if non-zero specifies a user specified timeout for
+ // a connection w/ pending data to send. A connection that has pending
+ // unacked data will be forcibily aborted if the timeout is reached
+ // without any data being acked.
+ userTimeout time.Duration
+
// pendingAccepted is a synchronization primitive used to track number
// of connections that are queued up to be delivered to the accepted
// channel. We use this to ensure that all goroutines blocked on writing
@@ -502,6 +525,36 @@ type endpoint struct {
// TODO(b/142022063): Add ability to save and restore per endpoint stats.
stats Stats `state:"nosave"`
+
+ // tcpLingerTimeout is the maximum amount of a time a socket
+ // a socket stays in TIME_WAIT state before being marked
+ // closed.
+ tcpLingerTimeout time.Duration
+
+ // closed indicates that the user has called closed on the
+ // endpoint and at this point the endpoint is only around
+ // to complete the TCP shutdown.
+ closed bool
+}
+
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
+// calculateAdvertisedMSS calculates the MSS to advertise.
+//
+// If userMSS is non-zero and is not greater than the maximum possible MSS for
+// r, it will be used; otherwise, the maximum possible MSS will be used.
+func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+ // The maximum possible MSS is dependent on the route.
+ maxMSS := mssForRoute(&r)
+
+ if userMSS != 0 && userMSS < maxMSS {
+ return userMSS
+ }
+
+ return maxMSS
}
// StopWork halts packet processing. Only to be used in tests.
@@ -550,6 +603,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
interval: 75 * time.Second,
count: 9,
},
+ uniqueID: s.UniqueID(),
}
var ss SendBufferSizeOption
@@ -572,6 +626,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.rcvAutoParams.disabled = !bool(mrb)
}
+ var de DelayEnabled
+ if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
+ e.SetSockOptInt(tcpip.DelayOption, 1)
+ }
+
+ var tcpLT tcpip.TCPLingerTimeoutOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil {
+ e.tcpLingerTimeout = time.Duration(tcpLT)
+ }
+
if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
@@ -659,6 +723,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
func (e *endpoint) Close() {
+ e.mu.Lock()
+ closed := e.closed
+ e.mu.Unlock()
+ if closed {
+ return
+ }
+
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
@@ -671,14 +742,18 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.isPortReserved = false
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
+ // Mark endpoint as closed.
+ e.closed = true
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
tcpip.AddDanglingEndpoint(e)
@@ -704,9 +779,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
go func() {
defer close(done)
for n := range e.acceptedChan {
- n.mu.Lock()
- n.resetConnectionLocked(tcpip.ErrConnectionAborted)
- n.mu.Unlock()
+ n.notifyProtocolGoroutine(notifyReset)
n.Close()
}
}()
@@ -732,16 +805,19 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.isPortReserved = false
}
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
e.route.Release()
+ e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
@@ -752,7 +828,9 @@ func (e *endpoint) initialReceiveWindow() int {
if rcvWnd > math.MaxUint16 {
rcvWnd = math.MaxUint16
}
- routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2
+
+ // Use the user supplied MSS, if available.
+ routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2
if rcvWnd > routeWnd {
rcvWnd = routeWnd
}
@@ -1133,16 +1211,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
e.sndBufMu.Unlock()
return nil
- default:
- return nil
- }
-}
-
-// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
- const inetECNMask = 3
- switch v := opt.(type) {
case tcpip.DelayOption:
if v == 0 {
atomic.StoreUint32(&e.delay, 0)
@@ -1154,6 +1222,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ default:
+ return nil
+ }
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
+ const inetECNMask = 3
+ switch v := opt.(type) {
case tcpip.CorkOption:
if v == 0 {
atomic.StoreUint32(&e.cork, 0)
@@ -1184,9 +1262,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.bindToDevice = 0
return nil
}
- for nicid, nic := range e.stack.NICInfo() {
+ for nicID, nic := range e.stack.NICInfo() {
if nic.Name == string(v) {
- e.bindToDevice = nicid
+ e.bindToDevice = nicID
return nil
}
}
@@ -1206,7 +1284,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidOptionValue
}
e.mu.Lock()
- e.userMSS = int(userMSS)
+ e.userMSS = uint16(userMSS)
e.mu.Unlock()
e.notifyProtocolGoroutine(notifyMSSChanged)
return nil
@@ -1262,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
return nil
+ case tcpip.TCPUserTimeoutOption:
+ e.mu.Lock()
+ e.userTimeout = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
@@ -1319,6 +1403,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.TCPLingerTimeoutOption:
+ e.mu.Lock()
+ if v < 0 {
+ // Same as effectively disabling TCPLinger timeout.
+ v = 0
+ }
+ var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption
+ if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil {
+ // We were unable to retrieve a stack config, just use
+ // the DefaultTCPLingerTimeout.
+ if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) {
+ stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout)
+ }
+ }
+ // Cap it to the stack wide TCPLinger timeout.
+ if v > stkTCPLingerTimeout {
+ v = stkTCPLingerTimeout
+ }
+ e.tcpLingerTimeout = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
default:
return nil
}
@@ -1345,6 +1451,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
+
case tcpip.SendBufferSizeOption:
e.sndBufMu.Lock()
v := e.sndBufSize
@@ -1357,8 +1464,16 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
e.rcvListMu.Unlock()
return v, nil
+ case tcpip.DelayOption:
+ var o int
+ if v := atomic.LoadUint32(&e.delay); v != 0 {
+ o = 1
+ }
+ return o, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
}
- return -1, tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -1379,13 +1494,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = header.TCPDefaultMSS
return nil
- case *tcpip.DelayOption:
- *o = 0
- if v := atomic.LoadUint32(&e.delay); v != 0 {
- *o = 1
- }
- return nil
-
case *tcpip.CorkOption:
*o = 0
if v := atomic.LoadUint32(&e.cork); v != 0 {
@@ -1496,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.keepalive.Unlock()
return nil
+ case *tcpip.TCPUserTimeoutOption:
+ e.mu.Lock()
+ *o = tcpip.TCPUserTimeoutOption(e.userTimeout)
+ e.mu.Unlock()
+ return nil
+
case *tcpip.OutOfBandInlineOption:
// We don't currently support disabling this option.
*o = 1
@@ -1530,6 +1644,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.RUnlock()
return nil
+ case *tcpip.TCPLingerTimeoutOption:
+ e.mu.Lock()
+ *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout)
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1602,7 +1722,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return tcpip.ErrAlreadyConnected
}
- nicid := addr.NIC
+ nicID := addr.NIC
switch e.state {
case StateBound:
// If we're already bound to a NIC but the caller is requesting
@@ -1611,11 +1731,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
break
}
- if nicid != 0 && nicid != e.boundNICID {
+ if nicID != 0 && nicID != e.boundNICID {
return tcpip.ErrNoRoute
}
- nicid = e.boundNICID
+ nicID = e.boundNICID
case StateInitial:
// Nothing to do. We'll eventually fill-in the gaps in the ID (if any)
@@ -1634,7 +1754,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
@@ -1649,7 +1769,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice)
if err != nil {
return err
}
@@ -1664,7 +1784,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// src IP to ensure that for a given tuple (srcIP, destIP,
// destPort) the offset used as a starting point is the same to
// ensure that we can cycle through the port space effectively.
- h := jenkins.Sum32(e.stack.PortSeed())
+ h := jenkins.Sum32(e.stack.Seed())
h.Write([]byte(e.ID.LocalAddress))
h.Write([]byte(e.ID.RemoteAddress))
portBuf := make([]byte, 2)
@@ -1678,15 +1798,18 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
// reusePort is false below because connect cannot reuse a port even if
// reusePort was set.
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) {
return false, nil
}
id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
+ switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
+ // Port picking successful. Save the details of
+ // the selected port.
e.ID = id
+ e.boundBindToDevice = e.bindToDevice
return true, nil
case tcpip.ErrPortInUse:
return false, nil
@@ -1702,14 +1825,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.isPortReserved = false
}
e.isRegistered = true
e.state = StateConnecting
e.route = r.Clone()
- e.boundNICID = nicid
+ e.boundNICID = nicID
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
@@ -1729,6 +1852,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
e.state = StateEstablished
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
}
if run {
@@ -1749,9 +1873,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
// peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
- defer e.mu.Unlock()
e.shutdownFlags |= flags
-
+ finQueued := false
switch {
case e.state.connected():
// Close for read.
@@ -1766,6 +1889,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// the connection with a RST.
if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
e.notifyProtocolGoroutine(notifyReset)
+ e.mu.Unlock()
return nil
}
}
@@ -1784,14 +1908,11 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
s := newSegmentFromView(&e.route, e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
-
+ finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
e.sndBufMu.Unlock()
-
- // Tell protocol goroutine to close.
- e.sndCloseWaker.Assert()
}
case e.state == StateListen:
@@ -1799,11 +1920,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
}
-
default:
+ e.mu.Unlock()
return tcpip.ErrNotConnected
}
-
+ e.mu.Unlock()
+ if finQueued {
+ if e.workMu.TryLock() {
+ e.handleClose()
+ e.workMu.Unlock()
+ } else {
+ // Tell protocol goroutine to close.
+ e.sndCloseWaker.Assert()
+ }
+ }
return nil
}
@@ -1844,6 +1974,15 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
return nil
}
+ if e.state == StateInitial {
+ // The listen is called on an unbound socket, the socket is
+ // automatically bound to a random free port with the local
+ // address set to INADDR_ANY.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return err
+ }
+ }
+
// Endpoint must be bound before it can transition to listen mode.
if e.state != StateBound {
e.stats.ReadErrors.InvalidEndpointState.Increment()
@@ -1851,7 +1990,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil {
return err
}
@@ -1895,12 +2034,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrWouldBlock
}
- // Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
-
- return n, wq, nil
+ return n, n.waiterQueue, nil
}
// Bind binds the endpoint to a specific local port and optionally address.
@@ -1908,6 +2042,10 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
+ return e.bindLocked(addr)
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
@@ -1932,26 +2070,33 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
+ flags := ports.Flags{
+ LoadBalanced: e.reusePort,
+ }
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice)
if err != nil {
return err
}
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = flags
e.isPortReserved = true
e.effectiveNetProtos = netProtos
e.ID.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func(bindToDevice tcpip.NICID) {
+ defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
e.ID.LocalPort = 0
e.ID.LocalAddress = ""
e.boundNICID = 0
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
- }(e.bindToDevice)
+ }(e.boundPortFlags, e.boundBindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
@@ -2001,8 +2146,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
- s := newSegment(r, id, vv)
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ s := newSegment(r, id, pkt)
if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment()
e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -2025,6 +2170,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
e.stack.Stats().TCP.ResetsReceived.Increment()
}
+ e.enqueueSegment(s)
+}
+
+func (e *endpoint) enqueueSegment(s *segment) {
// Send packet to worker goroutine.
if e.segmentQueue.enqueue(s) {
e.newSegmentWaker.Assert()
@@ -2037,7 +2186,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
@@ -2327,11 +2476,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
return s
}
-func (e *endpoint) initGSO() {
- if e.route.Capabilities()&stack.CapabilityGSO == 0 {
- return
- }
-
+func (e *endpoint) initHardwareGSO() {
gso := &stack.GSO{}
switch e.route.NetProto {
case header.IPv4ProtocolNumber:
@@ -2349,6 +2494,18 @@ func (e *endpoint) initGSO() {
e.gso = gso
}
+func (e *endpoint) initGSO() {
+ if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ e.initHardwareGSO()
+ } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 {
+ e.gso = &stack.GSO{
+ MaxSize: e.route.GSOMaxSize(),
+ Type: stack.GSOSW,
+ NeedsCsum: false,
+ }
+ }
+}
+
// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
// state for diagnostics.
func (e *endpoint) State() uint32 {
@@ -2371,6 +2528,23 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+// Wait implements stack.TransportEndpoint.Wait.
+func (e *endpoint) Wait() {
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp)
+ defer e.waiterQueue.EventUnregister(&waitEntry)
+ for {
+ e.mu.Lock()
+ running := e.workerRunning
+ e.mu.Unlock()
+ if !running {
+ break
+ }
+ <-notifyCh
+ }
+}
+
func mssForRoute(r *stack.Route) uint16 {
+ // TODO(b/143359391): Respect TCP Min and Max size.
return uint16(r.MTU() - header.TCPMinimumSize)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index eae17237e..7aa4c3f0e 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() {
}
fallthrough
case StateError, StateClose:
- for e.state == StateError && e.workerRunning {
+ for (e.state == StateError || e.state == StateClose) && e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
e.mu.Lock()
@@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
+ // Freeze segment queue before registering to prevent any segments
+ // from being delivered while it is being restored.
+ e.origEndpointState = e.state
+ // Restore the endpoint to InitialState as it will be moved to
+ // its origEndpointState during Resume.
+ e.state = StateInitial
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
@@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
+ state := e.origEndpointState
- state := e.state
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
@@ -189,12 +195,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
bind := func() {
- e.state = StateInitial
if len(e.BindAddr) == 0 {
e.BindAddr = e.ID.LocalAddress
}
- if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil {
- panic("endpoint binding failed: " + err.String())
+ addr := e.BindAddr
+ port := e.ID.LocalPort
+ if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil {
+ panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err))
}
}
@@ -217,6 +224,16 @@ func (e *endpoint) Resume(s *stack.Stack) {
if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
+ e.mu.Lock()
+ e.state = e.origEndpointState
+ closed := e.closed
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ if state == StateFinWait2 && closed {
+ // If the endpoint has been closed then make sure we notify so
+ // that the FIN_WAIT2 timer is started after a restore.
+ e.notifyProtocolGoroutine(notifyClose)
+ }
connectedLoading.Done()
case StateListen:
tcpip.AsyncLoading.Add(1)
@@ -263,8 +280,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
tcpip.AsyncLoading.Done()
}()
}
- fallthrough
+ e.state = StateClose
+ e.stack.CompleteTransportEndpointCleanup(e)
+ tcpip.DeleteDanglingEndpoint(e)
case StateError:
+ e.state = StateError
+ e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 63666f0b3..4983bca81 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -18,7 +18,6 @@ import (
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -63,8 +62,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
- s := newSegment(r, id, vv)
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
defer s.decRef()
// We only care about well-formed SYN packets.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index db40785d3..bc718064c 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -23,6 +23,7 @@ package tcp
import (
"strings"
"sync"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -54,12 +55,23 @@ const (
// MaxUnprocessedSegments is the maximum number of unprocessed segments
// that can be queued for a given endpoint.
MaxUnprocessedSegments = 300
+
+ // DefaultTCPLingerTimeout is the amount of time that sockets linger in
+ // FIN_WAIT_2 state before being marked closed.
+ DefaultTCPLingerTimeout = 60 * time.Second
+
+ // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
+ // in TIME_WAIT state before being marked closed.
+ DefaultTCPTimeWaitTimeout = 60 * time.Second
)
// SACKEnabled option can be used to enable SACK support in the TCP
// protocol. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
+// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol.
+type DelayEnabled bool
+
// SendBufferSizeOption allows the default, min and max send buffer sizes for
// TCP endpoints to be queried or configured.
type SendBufferSizeOption struct {
@@ -84,11 +96,14 @@ const (
type protocol struct {
mu sync.Mutex
sackEnabled bool
+ delayEnabled bool
sendBufferSize SendBufferSizeOption
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
+ tcpLingerTimeout time.Duration
+ tcpTimeWaitTimeout time.Duration
}
// Number returns the tcp protocol number.
@@ -97,7 +112,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil
}
@@ -126,8 +141,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
- s := newSegment(r, id, vv)
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
defer s.decRef()
if !s.parse() || !s.csumValid {
@@ -147,13 +162,26 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo
func replyWithReset(s *segment) {
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
+ ack := seqnum.Value(0)
+ flags := byte(header.TCPFlagRst)
+ // As per RFC 793 page 35 (Reset Generation)
+ // 1. If the connection does not exist (CLOSED) then a reset is sent
+ // in response to any incoming segment except another reset. In
+ // particular, SYNs addressed to a non-existent connection are rejected
+ // by this means.
+
+ // If the incoming segment has an ACK field, the reset takes its
+ // sequence number from the ACK field of the segment, otherwise the
+ // reset has sequence number zero and the ACK field is set to the sum
+ // of the sequence number and segment length of the incoming segment.
+ // The connection remains in the CLOSED state.
if s.flagIsSet(header.TCPFlagAck) {
seq = s.ackNumber
+ } else {
+ flags |= header.TCPFlagAck
+ ack = s.sequenceNumber.Add(s.logicalLen())
}
-
- ack := s.sequenceNumber.Add(s.logicalLen())
-
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
// SetOption implements TransportProtocol.SetOption.
@@ -165,6 +193,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case DelayEnabled:
+ p.mu.Lock()
+ p.delayEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
case SendBufferSizeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
@@ -202,6 +236,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case tcpip.TCPLingerTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.tcpLingerTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPTimeWaitTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.tcpTimeWaitTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -216,6 +268,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case *DelayEnabled:
+ p.mu.Lock()
+ *v = DelayEnabled(p.delayEnabled)
+ p.mu.Unlock()
+ return nil
+
case *SendBufferSizeOption:
p.mu.Lock()
*v = p.sendBufferSize
@@ -246,6 +304,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case *tcpip.TCPLingerTimeoutOption:
+ p.mu.Lock()
+ *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
+ p.mu.Unlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitTimeoutOption:
+ p.mu.Lock()
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -258,5 +328,7 @@ func NewProtocol() stack.TransportProtocol {
recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
+ tcpLingerTimeout: DefaultTCPLingerTimeout,
+ tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index e90f9a7d9..0a5534959 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -18,6 +18,7 @@ import (
"container/heap"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -49,16 +50,20 @@ type receiver struct {
pendingRcvdSegments segmentHeap
pendingBufUsed seqnum.Size
pendingBufSize seqnum.Size
+
+ // Time when the last ack was received.
+ lastRcvdAckTime time.Time `state:".(unixTime)"`
}
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
return &receiver{
- ep: ep,
- rcvNxt: irs + 1,
- rcvAcc: irs.Add(rcvWnd + 1),
- rcvWnd: rcvWnd,
- rcvWndScale: rcvWndScale,
- pendingBufSize: pendingBufSize,
+ ep: ep,
+ rcvNxt: irs + 1,
+ rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWnd: rcvWnd,
+ rcvWndScale: rcvWndScale,
+ pendingBufSize: pendingBufSize,
+ lastRcvdAckTime: time.Now(),
}
}
@@ -204,15 +209,20 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
r.ep.mu.Lock()
switch r.ep.state {
case StateFinWait1:
r.ep.state = StateFinWait2
+ // Notify protocol goroutine that we have received an
+ // ACK to our FIN so that it can start the FIN_WAIT2
+ // timer to abort connection if the other side does
+ // not close within 2MSL.
+ r.ep.notifyProtocolGoroutine(notifyClose)
case StateClosing:
r.ep.state = StateTimeWait
case StateLastAck:
- r.ep.state = StateClose
+ r.ep.transitionToStateCloseLocked()
}
r.ep.mu.Unlock()
}
@@ -253,25 +263,110 @@ func (r *receiver) updateRTT() {
r.ep.rcvListMu.Unlock()
}
-// handleRcvdSegment handles TCP segments directed at the connection managed by
-// r as they arrive. It is called by the protocol main loop.
-func (r *receiver) handleRcvdSegment(s *segment) {
+func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) {
+ r.ep.rcvListMu.Lock()
+ rcvClosed := r.ep.rcvClosed || r.closed
+ r.ep.rcvListMu.Unlock()
+
+ // If we are in one of the shutdown states then we need to do
+ // additional checks before we try and process the segment.
+ switch state {
+ case StateCloseWait, StateClosing, StateLastAck:
+ if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
+ s.decRef()
+ // Just drop the segment as we have
+ // already received a FIN and this
+ // segment is after the sequence number
+ // for the FIN.
+ return true, nil
+ }
+ fallthrough
+ case StateFinWait1:
+ fallthrough
+ case StateFinWait2:
+ // If we are closed for reads (either due to an
+ // incoming FIN or the user calling shutdown(..,
+ // SHUT_RD) then any data past the rcvNxt should
+ // trigger a RST.
+ endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
+ s.decRef()
+ return true, tcpip.ErrConnectionAborted
+ }
+ if state == StateFinWait1 {
+ break
+ }
+
+ // If it's a retransmission of an old data segment
+ // or a pure ACK then allow it.
+ if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) ||
+ s.logicalLen() == 0 {
+ break
+ }
+
+ // In FIN-WAIT2 if the socket is fully
+ // closed(not owned by application on our end
+ // then the only acceptable segment is a
+ // FIN. Since FIN can technically also carry
+ // data we verify that the segment carrying a
+ // FIN ends at exactly e.rcvNxt+1.
+ //
+ // From RFC793 page 25.
+ //
+ // For sequence number purposes, the SYN is
+ // considered to occur before the first actual
+ // data octet of the segment in which it occurs,
+ // while the FIN is considered to occur after
+ // the last actual data octet in a segment in
+ // which it occurs.
+ if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
+ s.decRef()
+ return true, tcpip.ErrConnectionAborted
+ }
+ }
+
// We don't care about receive processing anymore if the receive side
// is closed.
- if r.closed {
- return
+ //
+ // NOTE: We still want to permit a FIN as it's possible only our
+ // end has closed and the peer is yet to send a FIN. Hence we
+ // compare only the payload.
+ segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) {
+ return true, nil
+ }
+ return false, nil
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
+ r.ep.mu.RLock()
+ state := r.ep.state
+ closed := r.ep.closed
+ r.ep.mu.RUnlock()
+
+ if state != StateEstablished {
+ drop, err := r.handleRcvdSegmentClosing(s, state, closed)
+ if drop || err != nil {
+ return drop, err
+ }
}
segLen := seqnum.Size(s.data.Size())
segSeq := s.sequenceNumber
// If the sequence number range is outside the acceptable range, just
- // send an ACK. This is according to RFC 793, page 37.
+ // send an ACK and stop further processing of the segment.
+ // This is according to RFC 793, page 68.
if !r.acceptable(segSeq, segLen) {
r.ep.snd.sendAck()
- return
+ return true, nil
}
+ // Store the time of the last ack.
+ r.lastRcvdAckTime = time.Now()
+
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
@@ -288,7 +383,7 @@ func (r *receiver) handleRcvdSegment(s *segment) {
// have to retransmit.
r.ep.snd.sendAck()
}
- return
+ return false, nil
}
// Since we consumed a segment update the receiver's RTT estimate
@@ -315,4 +410,67 @@ func (r *receiver) handleRcvdSegment(s *segment) {
r.pendingBufUsed -= s.logicalLen()
s.decRef()
}
+ return false, nil
+}
+
+// handleTimeWaitSegment handles inbound segments received when the endpoint
+// has entered the TIME_WAIT state.
+func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) {
+ segSeq := s.sequenceNumber
+ segLen := seqnum.Size(s.data.Size())
+
+ // Just silently drop any RST packets in TIME_WAIT. We do not support
+ // TIME_WAIT assasination as a result we confirm w/ fix 1 as described
+ // in https://tools.ietf.org/html/rfc1337#section-3.
+ if s.flagIsSet(header.TCPFlagRst) {
+ return false, false
+ }
+
+ // If it's a SYN and the sequence number is higher than any seen before
+ // for this connection then try and redirect it to a listening endpoint
+ // if available.
+ //
+ // RFC 1122:
+ // "When a connection is [...] on TIME-WAIT state [...]
+ // [a TCP] MAY accept a new SYN from the remote TCP to
+ // reopen the connection directly, if it:
+
+ // (1) assigns its initial sequence number for the new
+ // connection to be larger than the largest sequence
+ // number it used on the previous connection incarnation,
+ // and
+
+ // (2) returns to TIME-WAIT state if the SYN turns out
+ // to be an old duplicate".
+ if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) {
+
+ return false, true
+ }
+
+ // Drop the segment if it does not contain an ACK.
+ if !s.flagIsSet(header.TCPFlagAck) {
+ return false, false
+ }
+
+ // Update Timestamp if required. See RFC7323, section-4.3.
+ if r.ep.sendTSOk && s.parsedOptions.TS {
+ r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq)
+ }
+
+ if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) {
+ // If it's a FIN-ACK then resetTimeWait and send an ACK, as it
+ // indicates our final ACK could have been lost.
+ r.ep.snd.sendAck()
+ return true, false
+ }
+
+ // If the sequence number range is outside the acceptable range or
+ // carries data then just send an ACK. This is according to RFC 793,
+ // page 37.
+ //
+ // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt.
+ if segSeq != r.rcvNxt || segLen != 0 {
+ r.ep.snd.sendAck()
+ }
+ return false, false
}
diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go
new file mode 100644
index 000000000..2bf21a2e7
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveLastRcvdAckTime is invoked by stateify.
+func (r *receiver) saveLastRcvdAckTime() unixTime {
+ return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()}
+}
+
+// loadLastRcvdAckTime is invoked by stateify.
+func (r *receiver) loadLastRcvdAckTime(unix unixTime) {
+ r.lastRcvdAckTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index ea725d513..1c10da5ca 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -60,13 +61,13 @@ type segment struct {
xmitTime time.Time `state:".(unixTime)"`
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
- s.data = vv.Clone(s.views[:])
+ s.data = pkt.Data.Clone(s.views[:])
s.rcvdTime = time.Now()
return s
}
@@ -99,8 +100,14 @@ func (s *segment) clone() *segment {
return t
}
-func (s *segment) flagIsSet(flag uint8) bool {
- return (s.flags & flag) != 0
+// flagIsSet checks if at least one flag in flags is set in s.flags.
+func (s *segment) flagIsSet(flags uint8) bool {
+ return s.flags&flags != 0
+}
+
+// flagsAreSet checks if all flags in flags are set in s.flags.
+func (s *segment) flagsAreSet(flags uint8) bool {
+ return s.flags&flags == flags
}
func (s *segment) decRef() {
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 8332a0179..8a947dc66 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -28,8 +28,11 @@ import (
)
const (
- // minRTO is the minimum allowed value for the retransmit timeout.
- minRTO = 200 * time.Millisecond
+ // MinRTO is the minimum allowed value for the retransmit timeout.
+ MinRTO = 200 * time.Millisecond
+
+ // MaxRTO is the maximum allowed value for the retransmit timeout.
+ MaxRTO = 120 * time.Second
// InitialCwnd is the initial congestion window.
InitialCwnd = 10
@@ -134,6 +137,10 @@ type sender struct {
// rttMeasureTime is the time when the rttMeasureSeqNum was sent.
rttMeasureTime time.Time `state:".(unixTime)"`
+ // firstRetransmittedSegXmitTime is the original transmit time of
+ // the first segment that was retransmitted due to RTO expiration.
+ firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
+
closed bool
writeNext *segment
writeList segmentList
@@ -392,8 +399,8 @@ func (s *sender) updateRTO(rtt time.Duration) {
s.rto = s.rtt.srtt + 4*s.rtt.rttvar
s.rtt.Unlock()
- if s.rto < minRTO {
- s.rto = minRTO
+ if s.rto < MinRTO {
+ s.rto = MinRTO
}
}
@@ -438,8 +445,30 @@ func (s *sender) retransmitTimerExpired() bool {
s.ep.stack.Stats().TCP.Timeouts.Increment()
s.ep.stats.SendErrors.Timeouts.Increment()
- // Give up if we've waited more than a minute since the last resend.
- if s.rto >= 60*time.Second {
+ // Give up if we've waited more than a minute since the last resend or
+ // if a user time out is set and we have exceeded the user specified
+ // timeout since the first retransmission.
+ s.ep.mu.RLock()
+ uto := s.ep.userTimeout
+ s.ep.mu.RUnlock()
+
+ if s.firstRetransmittedSegXmitTime.IsZero() {
+ // We store the original xmitTime of the segment that we are
+ // about to retransmit as the retransmission time. This is
+ // required as by the time the retransmitTimer has expired the
+ // segment has already been sent and unacked for the RTO at the
+ // time the segment was sent.
+ s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime
+ }
+
+ elapsed := time.Since(s.firstRetransmittedSegXmitTime)
+ remaining := MaxRTO
+ if uto != 0 {
+ // Cap to the user specified timeout if one is specified.
+ remaining = uto - elapsed
+ }
+
+ if remaining <= 0 || s.rto >= MaxRTO {
return false
}
@@ -447,6 +476,11 @@ func (s *sender) retransmitTimerExpired() bool {
// below.
s.rto *= 2
+ // Cap RTO to remaining time.
+ if s.rto > remaining {
+ s.rto = remaining
+ }
+
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
//
// Retransmit timeouts:
@@ -1168,6 +1202,8 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// RFC 6298 Rule 5.3
if s.sndUna == s.sndNxt {
s.outstanding = 0
+ // Reset firstRetransmittedSegXmitTime to the zero value.
+ s.firstRetransmittedSegXmitTime = time.Time{}
s.resendTimer.disable()
}
}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
index 12eff8afc..8b20c3455 100644
--- a/pkg/tcpip/transport/tcp/snd_state.go
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) {
func (s *sender) afterLoad() {
s.resendTimer.init(&s.resendWaker)
}
+
+// saveFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime {
+ return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()}
+}
+
+// loadFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) {
+ s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6d022a266..e8fe4dab5 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -75,6 +75,20 @@ func TestGiveUpConnect(t *testing.T) {
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
}
+
+ // Call Connect again to retreive the handshake failure status
+ // and stats updates.
+ if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
+ t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted)
+ }
+
+ if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
}
func TestConnectIncrementActiveConnection(t *testing.T) {
@@ -206,17 +220,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send a SYN request.
@@ -256,7 +271,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -264,6 +279,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
}
}
+ // Lower stackwide TIME_WAIT timeout so that the reservations
+ // are released instantly on Close.
+ tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
+ t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err)
+ }
+
c.EP.Close()
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
@@ -285,6 +307,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// Get the ACK to the FIN we just sent.
c.GetPacket()
+ // Since an active close was done we need to wait for a little more than
+ // tcpLingerTimeout for the port reservations to be released and the
+ // socket to move to a CLOSED state.
+ time.Sleep(20 * time.Millisecond)
+
// Now resend the same ACK, this ACK should generate a RST as there
// should be no endpoint in SYN-RCVD state and we are not using
// syn-cookies yet. The reason we send the same ACK is we need a valid
@@ -296,8 +323,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
}
func TestTCPResetsReceivedIncrement(t *testing.T) {
@@ -376,6 +403,13 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPLinger to 3 seconds so that sockets are marked closed
+ // after 3 second in FIN_WAIT2 state.
+ tcpLingerTimeout := 3 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err)
+ }
+
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -396,12 +430,24 @@ func TestConnectResetAfterClose(t *testing.T) {
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: 790,
- AckNum: c.IRS.Add(1),
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Wait for the ep to give up waiting for a FIN.
+ time.Sleep(tcpLingerTimeout + 1*time.Second)
+
+ // Now send an ACK and it should trigger a RST as the endpoint should
+ // not exist anymore.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
- // Wait for the ep to give up waiting for a FIN, and send a RST.
- time.Sleep(3 * time.Second)
for {
b := c.GetPacket()
tcpHdr := header.TCP(header.IPv4(b).Payload())
@@ -413,15 +459,128 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
),
)
break
}
}
+// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
+// when the endpoint transitions to StateClose. The in-flight segments would be
+// re-enqueued to a any listening endpoint.
+func TestClosingWithEnqueuedSegments(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Send a FIN for ESTABLISHED --> CLOSED-WAIT
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Get the ACK for the FIN we sent.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Close the application endpoint for CLOSE_WAIT --> LAST_ACK
+ ep.Close()
+
+ // Get the FIN
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Pause the endpoint`s protocolMainLoop.
+ ep.(interface{ StopWork() }).StopWork()
+
+ // Enqueue last ACK followed by an ACK matching the endpoint
+ //
+ // Send Last ACK for LAST_ACK --> CLOSED
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 791,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Send a packet with ACK set, this would generate RST when
+ // not using SYN cookies as in this test.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 792,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Unpause endpoint`s protocolMainLoop.
+ ep.(interface{ ResumeWork() }).ResumeWork()
+
+ // Wait for the protocolMainLoop to resume and update state.
+ time.Sleep(10 * time.Millisecond)
+
+ // Expect the endpoint to be closed.
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
+
+ // Check if the endpoint was moved to CLOSED and netstack a reset in
+ // response to the ACK packet that we sent after last-ACK.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+}
+
func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -474,6 +633,352 @@ func TestSimpleReceive(t *testing.T) {
)
}
+// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when
+// creating a new active IPv4 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
+ const mtu = 5000
+ const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ "EqualToMaxMSS",
+ maxMSS,
+ maxMSS,
+ },
+ {
+ "LessThanMTU",
+ maxMSS - 1,
+ maxMSS - 1,
+ },
+ {
+ "GreaterThanMTU",
+ maxMSS + 1,
+ maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ // Set the MSS socket option.
+ opt := tcpip.MaxSegOption(test.setMSS)
+ if err := c.EP.SetSockOpt(opt); err != nil {
+ t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ // Start connection attempt to IPv4 address.
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
+// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when
+// creating a new active IPv6 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
+ const mtu = 5000
+ const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ "EqualToMaxMSS",
+ maxMSS,
+ maxMSS,
+ },
+ {
+ "LessThanMTU",
+ maxMSS - 1,
+ maxMSS - 1,
+ },
+ {
+ "GreaterThanMTU",
+ maxMSS + 1,
+ maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ // Set the MSS socket option.
+ opt := tcpip.MaxSegOption(test.setMSS)
+ if err := c.EP.SetSockOpt(opt); err != nil {
+ t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ // Start connection attempt to IPv6 address.
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
+func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable ammount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv4.
+func TestTCPAckBeforeAcceptV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ // Send data before accepting the connection.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable ammount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv6.
+func TestTCPAckBeforeAcceptV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ // Send data before accepting the connection.
+ c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestSendRstOnListenerRxAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true /* v6Only */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+// TestListenShutdown tests for the listening endpoint not processing
+// any receive when it is on read shutdown.
+func TestListenShutdown(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatal("Shutdown failed:", err)
+ }
+
+ // Wait for the endpoint state to be propagated.
+ time.Sleep(10 * time.Millisecond)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ c.CheckNoPacket("Packet received when listening socket was shutdown")
+}
+
func TestTOSV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -635,6 +1140,71 @@ func TestConnectBindToDevice(t *testing.T) {
}
}
+func TestRstOnSynSent(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create an endpoint, don't handshake because we want to interfere with the
+ // handshake process.
+ c.Create(-1)
+
+ // Start connection attempt.
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+
+ // Ensure that we've reached SynSent state
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ // Send a packet with a proper ACK and a RST flag to cause the socket
+ // to Error and close out
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for packet to arrive")
+ }
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused)
+ }
+
+ // Due to the RST the endpoint should be in an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+}
+
func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -930,8 +1500,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- // We shouldn't consume a sequence number on RST.
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.SeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
@@ -1623,7 +2192,7 @@ func TestDelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOpt(tcpip.DelayOption(1))
+ c.EP.SetSockOptInt(tcpip.DelayOption, 1)
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
@@ -1671,7 +2240,7 @@ func TestUndelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOpt(tcpip.DelayOption(1))
+ c.EP.SetSockOptInt(tcpip.DelayOption, 1)
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
@@ -1704,7 +2273,7 @@ func TestUndelay(t *testing.T) {
// Check that we don't get the second packet yet.
c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
- c.EP.SetSockOpt(tcpip.DelayOption(0))
+ c.EP.SetSockOptInt(tcpip.DelayOption, 0)
// Check that data is received.
second := c.GetPacket()
@@ -1741,7 +2310,7 @@ func TestMSSNotDelayed(t *testing.T) {
fn func(tcpip.Endpoint)
}{
{"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.DelayOption(1)) }},
+ {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }},
{"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }},
}
@@ -2211,6 +2780,13 @@ loop:
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
+
+ if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
+ t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
}
func TestSendOnResetConnection(t *testing.T) {
@@ -2905,6 +3481,13 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
+ // after 1 second in TIME_WAIT state.
+ tcpTimeWaitTimeout := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
@@ -2912,12 +3495,12 @@ func TestReadAfterClosedState(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock)
}
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2955,10 +3538,9 @@ func TestReadAfterClosedState(t *testing.T) {
),
)
- // Give the stack the chance to transition to closed state. Note that since
- // both the sender and receiver are now closed, we effectively skip the
- // TIME-WAIT state.
- time.Sleep(1 * time.Second)
+ // Give the stack the chance to transition to closed state from
+ // TIME_WAIT.
+ time.Sleep(tcpTimeWaitTimeout * 2)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
@@ -2975,7 +3557,7 @@ func TestReadAfterClosedState(t *testing.T) {
peekBuf := make([]byte, 10)
n, _, err := c.EP.Peek([][]byte{peekBuf})
if err != nil {
- t.Fatalf("Peek failed: %v", err)
+ t.Fatalf("Peek failed: %s", err)
}
peekBuf = peekBuf[:n]
@@ -2986,7 +3568,7 @@ func TestReadAfterClosedState(t *testing.T) {
// Receive data.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -2996,11 +3578,11 @@ func TestReadAfterClosedState(t *testing.T) {
// Now that we drained the queue, check that functions fail with the
// right error code.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive)
}
if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -3773,8 +4355,9 @@ func TestKeepalive(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ const keepAliveInterval = 10 * time.Millisecond
c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5))
c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
@@ -3864,19 +4447,43 @@ func TestKeepalive(t *testing.T) {
)
}
+ // Sleep for a litte over the KeepAlive interval to make sure
+ // the timer has time to fire after the last ACK and close the
+ // close the socket.
+ time.Sleep(keepAliveInterval + 5*time.Millisecond)
+
// The connection should be terminated after 5 unacked keepalives.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(790)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
),
)
+ if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ }
+
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
}
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
}
func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
@@ -3890,7 +4497,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
RcvWnd: 30000,
})
- // Receive the SYN-ACK reply.w
+ // Receive the SYN-ACK reply.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
iss = seqnum.Value(tcp.SequenceNumber())
@@ -3923,6 +4530,50 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
return irs, iss
}
+func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ // Send a SYN request.
+ irs = seqnum.Value(789)
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetV6Packet()
+ tcp := header.TCP(header.IPv6(b).Payload())
+ iss = seqnum.Value(tcp.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+
+ if synCookieInUse {
+ // When cookies are in use window scaling is disabled.
+ tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
+ WS: -1,
+ MSS: c.MSSWithoutOptionsV6(),
+ }))
+ }
+
+ checker.IPv6(t, b, checker.TCP(tcpCheckers...))
+
+ // Send ACK.
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+ return irs, iss
+}
+
// TestListenBacklogFull tests that netstack does not complete handshakes if the
// listen backlog for the endpoint is full.
func TestListenBacklogFull(t *testing.T) {
@@ -4025,6 +4676,210 @@ func TestListenBacklogFull(t *testing.T) {
}
}
+// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a
+// non unicast IPv4 address are not accepted.
+func TestListenNoAcceptNonUnicastV4(t *testing.T) {
+ multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
+ otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+
+ tests := []struct {
+ name string
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ }{
+ {
+ "SourceUnspecified",
+ header.IPv4Any,
+ context.StackAddr,
+ },
+ {
+ "SourceBroadcast",
+ header.IPv4Broadcast,
+ context.StackAddr,
+ },
+ {
+ "SourceOurMulticast",
+ multicastAddr,
+ context.StackAddr,
+ },
+ {
+ "SourceOtherMulticast",
+ otherMulticastAddr,
+ context.StackAddr,
+ },
+ {
+ "DestUnspecified",
+ context.TestAddr,
+ header.IPv4Any,
+ },
+ {
+ "DestBroadcast",
+ context.TestAddr,
+ header.IPv4Broadcast,
+ },
+ {
+ "DestOurMulticast",
+ context.TestAddr,
+ multicastAddr,
+ },
+ {
+ "DestOtherMulticast",
+ context.TestAddr,
+ otherMulticastAddr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup failed: %s", err)
+ }
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ irs := seqnum.Value(789)
+ c.SendPacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, test.srcAddr, test.dstAddr)
+ c.CheckNoPacket("Should not have received a response")
+
+ // Handle normal packet.
+ c.SendPacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, context.TestAddr, context.StackAddr)
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+ })
+ }
+}
+
+// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
+// non unicast IPv6 address are not accepted.
+func TestListenNoAcceptNonUnicastV6(t *testing.T) {
+ multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
+ otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+
+ tests := []struct {
+ name string
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ }{
+ {
+ "SourceUnspecified",
+ header.IPv6Any,
+ context.StackV6Addr,
+ },
+ {
+ "SourceAllNodes",
+ header.IPv6AllNodesMulticastAddress,
+ context.StackV6Addr,
+ },
+ {
+ "SourceOurMulticast",
+ multicastAddr,
+ context.StackV6Addr,
+ },
+ {
+ "SourceOtherMulticast",
+ otherMulticastAddr,
+ context.StackV6Addr,
+ },
+ {
+ "DestUnspecified",
+ context.TestV6Addr,
+ header.IPv6Any,
+ },
+ {
+ "DestAllNodes",
+ context.TestV6Addr,
+ header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ "DestOurMulticast",
+ context.TestV6Addr,
+ multicastAddr,
+ },
+ {
+ "DestOtherMulticast",
+ context.TestV6Addr,
+ otherMulticastAddr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup failed: %s", err)
+ }
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ irs := seqnum.Value(789)
+ c.SendV6PacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, test.srcAddr, test.dstAddr)
+ c.CheckNoPacket("Should not have received a response")
+
+ // Handle normal packet.
+ c.SendV6PacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, context.TestV6Addr, context.StackV6Addr)
+ checker.IPv6(t, c.GetV6Packet(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+ })
+ }
+}
+
func TestListenSynRcvdQueueFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -4056,7 +4911,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(789),
+ SeqNum: irs,
RcvWnd: 30000,
})
@@ -4167,7 +5022,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
// Send a SYN request.
irs := seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
+ // pick a different src port for new SYN.
+ SrcPort: context.TestPort + 1,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
@@ -4207,6 +5063,125 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
}
+func TestSynRcvdBadSeqNumber(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Start listening.
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcpHdr.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Now send a packet with an out-of-window sequence number
+ largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: largeSeqnum,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ // Should receive an ACK with the expected SEQ number
+ b = c.GetPacket()
+ tcpCheckers = []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.AckNum(uint32(irs) + 1),
+ checker.SeqNum(uint32(iss + 1)),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Now that the socket replied appropriately with the ACK,
+ // complete the connection to test that the large SEQ num
+ // did not change the state from SYN-RCVD.
+
+ // Send ACK to move to ESTABLISHED state.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ newEP, _, err := c.EP.Accept()
+
+ if err != nil && err != tcpip.ErrWouldBlock {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ if err == tcpip.ErrWouldBlock {
+ // Try to accept the connections in the backlog.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ newEP, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now verify that the TCP socket is usable and in a connected state.
+ data := "Don't panic"
+ _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ pkt := c.GetPacket()
+ tcpHdr = header.TCP(header.IPv4(pkt).Payload())
+ if string(tcpHdr.Payload()) != data {
+ t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
+ }
+}
+
func TestPassiveConnectionAttemptIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -4381,6 +5356,9 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
+ t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected)
+ }
// Listening endpoint remains in listen state.
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
@@ -4668,3 +5646,918 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
payloadSize *= 2
}
}
+
+func TestDelayEnabled(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ checkDelayOption(t, c, false, 0) // Delay is disabled by default.
+
+ for _, v := range []struct {
+ delayEnabled tcp.DelayEnabled
+ wantDelayOption int
+ }{
+ {delayEnabled: false, wantDelayOption: 0},
+ {delayEnabled: true, wantDelayOption: 1},
+ } {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
+ t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err)
+ }
+ checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
+ }
+}
+
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) {
+ t.Helper()
+
+ var gotDelayEnabled tcp.DelayEnabled
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
+ t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err)
+ }
+ if gotDelayEnabled != wantDelayEnabled {
+ t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
+ }
+
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
+ if err != nil {
+ t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err)
+ }
+ gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption)
+ if err != nil {
+ t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err)
+ }
+ if gotDelayOption != wantDelayOption {
+ t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption)
+ }
+}
+
+func TestTCPLingerTimeout(t *testing.T) {
+ c := context.New(t, 1500 /* mtu */)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ testCases := []struct {
+ name string
+ tcpLingerTimeout time.Duration
+ want time.Duration
+ }{
+ {"NegativeLingerTimeout", -123123, 0},
+ {"ZeroLingerTimeout", 0, 0},
+ {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
+ // Values > stack's TCPLingerTimeout are capped to the stack's
+ // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
+ {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil {
+ t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err)
+ }
+ var v tcpip.TCPLingerTimeoutOption
+ if err := c.EP.GetSockOpt(&v); err != nil {
+ t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err)
+ }
+ if got, want := time.Duration(v), tc.want; got != want {
+ t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
+
+func TestTCPTimeWaitRSTIgnored(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Now send a RST and this should be ignored and not
+ // generate an ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagRst,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ })
+
+ c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second)
+
+ // Out of order ACK should generate an immediate ACK in
+ // TIME_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 3,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+}
+
+func TestTCPTimeWaitOutOfOrder(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Out of order ACK should generate an immediate ACK in
+ // TIME_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 3,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+}
+
+func TestTCPTimeWaitNewSyn(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Send a SYN request w/ sequence number lower than
+ // the highest sequence number sent. We just reuse
+ // the same number.
+ iss = seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
+
+ // Send a SYN request w/ sequence number higher than
+ // the highest sequence number sent.
+ iss = seqnum.Value(792)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b = c.GetPacket()
+ tcpHdr = header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+}
+
+func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+ // after 5 seconds in TIME_WAIT state.
+ tcpTimeWaitTimeout := 5 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ time.Sleep(2 * time.Second)
+
+ // Now send a duplicate FIN. This should cause the TIME_WAIT to extend
+ // by another 5 seconds and also send us a duplicate ACK as it should
+ // indicate that the final ACK was potentially lost.
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Sleep for 4 seconds so at this point we are 1 second past the
+ // original tcpLingerTimeout of 5 seconds.
+ time.Sleep(4 * time.Second)
+
+ // Send an ACK and it should not generate any packet as the socket
+ // should still be in TIME_WAIT for another another 5 seconds due
+ // to the duplicate FIN we sent earlier.
+ *ackHeaders = *finHeaders
+ ackHeaders.SeqNum = ackHeaders.SeqNum + 1
+ ackHeaders.Flags = header.TCPFlagAck
+ c.SendPacket(nil, ackHeaders)
+
+ c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
+ // Now sleep for another 2 seconds so that we are past the
+ // extended TIME_WAIT of 7 seconds (2 + 5).
+ time.Sleep(2 * time.Second)
+
+ // Resend the same ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Receive the RST that should be generated as there is no valid
+ // endpoint.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(ackHeaders.AckNum)),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
+}
+
+func TestTCPCloseWithData(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+ // after 5 seconds in TIME_WAIT state.
+ tcpTimeWaitTimeout := 5 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ RcvWnd: 30000,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now trigger a passive close by sending a FIN.
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ RcvWnd: 30000,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Now write a few bytes and then close the endpoint.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b = c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+
+ c.EP.Close()
+ // Check the FIN.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
+ checker.AckNum(uint32(iss+2)),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ // First send a partial ACK.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)-1),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Now send a full ACK.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Now ACK the FIN.
+ ackHeaders.AckNum++
+ c.SendPacket(nil, ackHeaders)
+
+ // Now send an ACK and we should get a RST back as the endpoint should
+ // be in CLOSED state.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Check the RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(ackHeaders.AckNum)),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+}
+
+func TestTCPUserTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+ userTimeout := 50 * time.Millisecond
+ c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+ // Send some data and wait before ACKing it.
+ view := buffer.NewView(3)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ next := uint32(c.IRS) + 1
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Wait for a little over the minimum retransmit timeout of 200ms for
+ // the retransmitTimer to fire and close the connection.
+ time.Sleep(tcp.MinRTO + 10*time.Millisecond)
+
+ // No packet should be received as the connection should be silently
+ // closed due to timeout.
+ c.CheckNoPacket("unexpected packet received after userTimeout has expired")
+
+ next += uint32(len(view))
+
+ // The connection should be terminated after userTimeout has expired.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(next)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ }
+
+ if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ }
+}
+
+func TestKeepaliveWithUserTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+ const keepAliveInterval = 10 * time.Millisecond
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+ c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10))
+ c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
+
+ // Set userTimeout to be the duration for 3 keepalive probes.
+ userTimeout := 30 * time.Millisecond
+ c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+ // Check that the connection is still alive.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ }
+
+ // Now receive 2 keepalives, but don't ACK them. The connection should
+ // be reset when the 3rd one should be sent due to userTimeout being
+ // 30ms and each keepalive probe should be sent 10ms apart as set above after
+ // the connection has been idle for 10ms.
+ for i := 0; i < 2; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ // Sleep for a litte over the KeepAlive interval to make sure
+ // the timer has time to fire after the last ACK and close the
+ // close the socket.
+ time.Sleep(keepAliveInterval + 5*time.Millisecond)
+
+ // The connection should be terminated after 30ms.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(c.IRS + 1),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ }
+ if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
index 19b0d31c5..b33ec2087 100644
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ b/pkg/tcpip/transport/tcp/testing/context/BUILD
@@ -8,7 +8,7 @@ go_library(
srcs = ["context.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context",
visibility = [
- "//:sandbox",
+ "//visibility:public",
],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index ef823e4ae..b0a376eba 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -231,14 +231,15 @@ func (c *Context) CheckNoPacket(errMsg string) {
// addresses. It will fail with an error if no packet is received for
// 2 seconds.
func (c *Context) GetPacket() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv4.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
@@ -259,14 +260,15 @@ func (c *Context) GetPacket() []byte {
// and destination address. If no packet is available it will return
// nil immediately.
func (c *Context) GetPacketNonBlocking() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv4.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
return b
@@ -302,11 +304,19 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
// BuildSegment builds a TCP segment based on the given Headers and payload.
func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
+ return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
+}
+
+// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
+// payload and source and destination IPv4 addresses.
+func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -319,8 +329,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(tcp.ProtocolNumber),
- SrcAddr: TestAddr,
- DstAddr: StackAddr,
+ SrcAddr: src,
+ DstAddr: dst,
})
ip.SetChecksum(^ip.CalculateChecksum())
@@ -337,7 +347,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
})
// Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t)))
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -350,13 +360,26 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.Inject(ipv4.ProtocolNumber, s)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: s,
+ })
}
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h))
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: c.BuildSegment(payload, h),
+ })
+}
+
+// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
+// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
+// provided source and destination IPv4 addresses.
+func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
+ })
}
// SendAck sends an ACK packet.
@@ -462,14 +485,15 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
// GetV6Packet reads a single packet from the link layer endpoint of the context
// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
func (c *Context) GetV6Packet() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv6.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+ b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
+ copy(b, p.Pkt.Header.View())
+ copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
return b
@@ -484,6 +508,13 @@ func (c *Context) GetV6Packet() []byte {
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
// the context.
func (c *Context) SendV6Packet(payload []byte, h *Headers) {
+ c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
+}
+
+// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
+// endpoint of the context using the provided source and destination IPv6
+// addresses.
+func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -494,8 +525,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
NextHeader: uint8(tcp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: TestV6Addr,
- DstAddr: StackV6Addr,
+ SrcAddr: src,
+ DstAddr: dst,
})
// Initialize the TCP header.
@@ -511,14 +542,16 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
})
// Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t)))
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
// CreateConnected creates a connected TCP endpoint.
@@ -1059,3 +1092,9 @@ func (c *Context) SetGSOEnabled(enable bool) {
func (c *Context) MSSWithoutOptions() uint16 {
return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
}
+
+// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
+// options are in use for IPv6 packets.
+func (c *Context) MSSWithoutOptionsV6() uint16 {
+ return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
+}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c9460aa0d..97e4d5825 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
@@ -59,11 +60,3 @@ go_test(
"//pkg/waiter",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "udp_packet_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 6e87245b7..1ac4705af 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -31,9 +32,6 @@ type udpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- // views is used as buffer for data when its length is large
- // enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
}
// EndpointState represents the state of a UDP endpoint.
@@ -80,6 +78,7 @@ type endpoint struct {
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
+ uniqueID uint64
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -106,6 +105,11 @@ type endpoint struct {
bindToDevice tcpip.NICID
broadcast bool
+ // Values used to reserve a port or register a transport endpoint.
+ // (which ever happens first).
+ boundBindToDevice tcpip.NICID
+ boundPortFlags ports.Flags
+
// sendTOS represents IPv4 TOS or IPv6 TrafficClass,
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
@@ -140,7 +144,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
- TransProto: header.TCPProtocolNumber,
+ TransProto: header.UDPProtocolNumber,
},
waiterQueue: waiterQueue,
// RFC 1075 section 5.4 recommends a TTL of 1 for membership
@@ -160,9 +164,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
state: StateInitial,
+ uniqueID: s.UniqueID(),
}
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -171,8 +181,10 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
for _, mem := range e.multicastMemberships {
@@ -278,7 +290,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.ID.LocalAddress
if isBroadcastOrMulticast(localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -286,20 +298,20 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
}
if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
- if nicid == 0 {
- nicid = e.multicastNICID
+ if nicID == 0 {
+ nicID = e.multicastNICID
}
- if localAddr == "" && nicid == 0 {
+ if localAddr == "" && nicID == 0 {
localAddr = e.multicastAddr
}
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
if err != nil {
return stack.Route{}, 0, err
}
- return r, nicid, nil
+ return r, nicID, nil
}
// Write writes data to the endpoint's peer. This method does not block
@@ -378,13 +390,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
} else {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
- nicid := to.NIC
+ nicID := to.NIC
if e.BindNICID != 0 {
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
}
if to.Addr == header.IPv4Broadcast && !e.broadcast {
@@ -396,7 +408,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, err
}
- r, _, err := e.connectRoute(nicid, *to, netProto)
+ r, _, err := e.connectRoute(nicID, *to, netProto)
if err != nil {
return 0, nil, err
}
@@ -618,9 +630,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.bindToDevice = 0
return nil
}
- for nicid, nic := range e.stack.NICInfo() {
+ for nicID, nic := range e.stack.NICInfo() {
if nic.Name == string(v) {
- e.bindToDevice = nicid
+ e.bindToDevice = nicID
return nil
}
}
@@ -728,6 +740,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = tcpip.MulticastLoopOption(v)
return nil
+ case *tcpip.ReuseAddressOption:
+ *o = 0
+ return nil
+
case *tcpip.ReusePortOption:
e.mu.RLock()
v := e.reusePort
@@ -809,7 +825,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
if useDefaultTTL {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: data,
+ TransportHeader: buffer.View(udp),
+ }); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return err
}
@@ -859,7 +879,10 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if e.state != StateConnected {
return nil
}
- id := stack.TransportEndpointID{}
+ var (
+ id stack.TransportEndpointID
+ btd tcpip.NICID
+ )
// Exclude ephemerally bound endpoints.
if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
@@ -867,7 +890,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
LocalPort: e.ID.LocalPort,
LocalAddress: e.ID.LocalAddress,
}
- id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
return err
}
@@ -875,13 +898,15 @@ func (e *endpoint) Disconnect() *tcpip.Error {
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.boundPortFlags = ports.Flags{}
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.ID = id
+ e.boundBindToDevice = btd
e.route.Release()
e.route = stack.Route{}
e.dstPort = 0
@@ -903,7 +928,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicid := addr.NIC
+ nicID := addr.NIC
var localPort uint16
switch e.state {
case StateInitial:
@@ -913,16 +938,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
break
}
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
- r, nicid, err := e.connectRoute(nicid, addr, netProto)
+ r, nicID, err := e.connectRoute(nicID, addr, netProto)
if err != nil {
return err
}
@@ -950,20 +975,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
}
- id, err = e.registerWithStack(nicid, netProtos, id)
+ id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
// Remove the old registration.
if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
}
e.ID = id
+ e.boundBindToDevice = btd
e.route = r.Clone()
e.dstPort = addr.Port
- e.RegisterNICID = nicid
+ e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
e.state = StateConnected
@@ -1018,20 +1044,27 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
-func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
+ flags := ports.Flags{
+ LoadBalanced: e.reusePort,
+ // FIXME(b/129164367): Support SO_REUSEADDR.
+ MostRecent: false,
+ }
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice)
if err != nil {
- return id, err
+ return id, e.bindToDevice, err
}
+ e.boundPortFlags = flags
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice)
+ e.boundPortFlags = ports.Flags{}
}
- return id, err
+ return id, e.bindToDevice, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
@@ -1057,11 +1090,11 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
}
}
- nicid := addr.NIC
+ nicID := addr.NIC
if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
// A local unicast address was specified, verify that it's valid.
- nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
- if nicid == 0 {
+ nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nicID == 0 {
return tcpip.ErrBadLocalAddress
}
}
@@ -1070,13 +1103,14 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
LocalPort: addr.Port,
LocalAddress: addr.Addr,
}
- id, err = e.registerWithStack(nicid, netProtos, id)
+ id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
e.ID = id
- e.RegisterNICID = nicid
+ e.boundBindToDevice = btd
+ e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
@@ -1111,9 +1145,14 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
+ addr := e.ID.LocalAddress
+ if e.state == StateConnected {
+ addr = e.route.LocalAddress
+ }
+
return tcpip.FullAddress{
NIC: e.RegisterNICID,
- Addr: e.ID.LocalAddress,
+ Addr: addr,
Port: e.ID.LocalPort,
}, nil
}
@@ -1154,17 +1193,17 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
// Get the header then trim it from the view.
- hdr := header.UDP(vv.First())
- if int(hdr.Length()) > vv.Size() {
+ hdr := header.UDP(pkt.Data.First())
+ if int(hdr.Length()) > pkt.Data.Size() {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
- vv.TrimFront(header.UDPMinimumSize)
+ pkt.Data.TrimFront(header.UDPMinimumSize)
e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
@@ -1188,18 +1227,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- pkt := &udpPacket{
+ packet := &udpPacket{
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
Port: hdr.SourcePort(),
},
}
- pkt.data = vv.Clone(pkt.views[:])
- e.rcvList.PushBack(pkt)
- e.rcvBufSize += vv.Size()
+ packet.data = pkt.Data
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += pkt.Data.Size()
- pkt.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
@@ -1210,7 +1249,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
}
// State implements tcpip.Endpoint.State.
@@ -1234,6 +1273,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+// Wait implements tcpip.Endpoint.Wait.
+func (*endpoint) Wait() {}
+
func isBroadcastOrMulticast(a tcpip.Address) bool {
return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index b227e353b..43fb047ed 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -109,7 +109,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
// pass it to the reservation machinery.
id := e.ID
e.ID.LocalPort = 0
- e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index d399ec722..fc706ede2 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -16,7 +16,6 @@ package udp
import (
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -44,12 +43,12 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
route: r,
id: id,
- vv: vv,
+ pkt: pkt,
})
return true
@@ -62,7 +61,7 @@ type ForwarderRequest struct {
stack *stack.Stack
route *stack.Route
id stack.TransportEndpointID
- vv buffer.VectorisedView
+ pkt tcpip.PacketBuffer
}
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
@@ -90,7 +89,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.rcvReady = true
ep.rcvMu.Unlock()
- ep.HandlePacket(r.route, r.id, r.vv)
+ ep.HandlePacket(r.route, r.id, r.pkt)
return ep, nil
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index de026880f..259c3072a 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -66,10 +66,10 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
// Get the header then trim it from the view.
- hdr := header.UDP(vv.First())
- if int(hdr.Length()) > vv.Size() {
+ hdr := header.UDP(pkt.Data.First())
+ if int(hdr.Length()) > pkt.Data.Size() {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true
@@ -116,13 +116,18 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
- payloadLen := len(netHeader) + vv.Size()
+ payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
- payload.Append(vv)
+ // The buffers used by pkt may be used elsewhere in the system.
+ // For example, a raw or packet socket may use what UDP
+ // considers an unreachable destination. Thus we deep copy pkt
+ // to prevent multiple ownership and SR errors.
+ newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...)
+ payload := newNetHeader.ToVectorisedView()
+ payload.Append(pkt.Data.ToView().ToVectorisedView())
payload.CapLength(payloadLen)
hdr := buffer.NewPrependable(headerLen)
@@ -130,7 +135,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv4DstUnreachable)
pkt.SetCode(header.ICMPv4PortUnreachable)
pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
case header.IPv6AddressSize:
if !r.Stack().AllowICMPMessage() {
@@ -151,12 +159,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
available := int(mtu) - headerLen
- payloadLen := len(netHeader) + vv.Size()
+ payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
- payload.Append(vv)
+ payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader})
+ payload.Append(pkt.Data)
payload.CapLength(payloadLen)
hdr := buffer.NewPrependable(headerLen)
@@ -164,7 +172,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv6DstUnreachable)
pkt.SetCode(header.ICMPv6PortUnreachable)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
}
return true
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index b724d788c..7051a7a9c 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -356,9 +356,9 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw
if p.Proto != flow.netProto() {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
h := flow.header4Tuple(outgoing)
checkers := append(
@@ -397,7 +397,8 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
// Initialize the IP header.
ip := header.IPv6(buf)
@@ -431,7 +432,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
+ })
}
// injectV4Packet creates a V4 test packet with the given payload and header
@@ -441,7 +446,8 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
// Initialize the IP header.
ip := header.IPv4(buf)
@@ -471,7 +477,12 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
+
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
+ })
}
func newPayload() []byte {
@@ -1442,8 +1453,8 @@ func TestV4UnknownDestination(t *testing.T) {
select {
case p := <-c.linkEP.C:
var pkt []byte
- pkt = append(pkt, p.Header...)
- pkt = append(pkt, p.Payload...)
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
}
@@ -1516,8 +1527,8 @@ func TestV6UnknownDestination(t *testing.T) {
select {
case p := <-c.linkEP.C:
var pkt []byte
- pkt = append(pkt, p.Header...)
- pkt = append(pkt, p.Payload...)
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
}