summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go97
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go43
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go259
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go19
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go232
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go38
-rw-r--r--pkg/tcpip/transport/tcp/BUILD21
-rw-r--r--pkg/tcpip/transport/tcp/accept.go188
-rw-r--r--pkg/tcpip/transport/tcp/connect.go143
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go150
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go57
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go719
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go106
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go238
-rw-r--r--pkg/tcpip/transport/tcp/rack.go124
-rw-r--r--pkg/tcpip/transport/tcp/rack_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go127
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment.go91
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go52
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go23
-rw-r--r--pkg/tcpip/transport/tcp/snd.go157
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go32
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go137
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go27
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go2375
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go29
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go208
-rw-r--r--pkg/tcpip/transport/tcp/timer.go1
-rw-r--r--pkg/tcpip/transport/tcp/timer_test.go47
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go5
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go349
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go2
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go8
-rw-r--r--pkg/tcpip/transport/udp/protocol.go173
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go945
39 files changed, 4885 insertions, 2372 deletions
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 9ce625c17..7e5c79776 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index b1d820372..a17234946 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -73,6 +74,8 @@ type endpoint struct {
route stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -110,7 +113,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -140,11 +143,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -347,7 +345,16 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
}
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
+ }
return nil
}
@@ -371,7 +378,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
default:
@@ -415,9 +422,12 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ *o = e.linger
+ e.mu.Unlock()
return nil
default:
@@ -430,9 +440,13 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return tcpip.ErrInvalidEndpointState
}
- hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()),
+ })
+ pkt.Owner = owner
- icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
copy(icmpv4, data)
// Set the ident to the user-specified port. Sequence number should
// already be set by the user.
@@ -447,15 +461,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+ pkt.Data = data.ToVectorisedView()
+
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: data.ToVectorisedView(),
- TransportHeader: buffer.View(icmpv4),
- Owner: owner,
- })
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
}
func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
@@ -463,9 +474,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
return tcpip.ErrInvalidEndpointState
}
- hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()),
+ })
- icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
copy(icmpv6, data)
// Set the ident. Sequence number is provided by the user.
icmpv6.SetIdent(ident)
@@ -477,15 +491,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
dataVV := data.ToVectorisedView()
icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV))
+ pkt.Data = dataVV
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: dataVV,
- TransportHeader: buffer.View(icmpv6),
- })
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
}
// checkV4MappedLocked determines the effective network protocol and converts
@@ -511,6 +522,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
localPort := uint16(0)
switch e.state {
+ case stateInitial:
case stateBound, stateConnected:
localPort = e.ID.LocalPort
if e.BindNICID == 0 {
@@ -603,7 +615,7 @@ func (*endpoint) Listen(int) *tcpip.Error {
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -611,14 +623,14 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
@@ -743,19 +755,23 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
- if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply {
+ h := header.ICMPv4(pkt.TransportHeader().View())
+ // TODO(b/129292233): Determine if len(h) check is still needed after early
+ // parsing.
+ if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
- if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply {
+ h := header.ICMPv6(pkt.TransportHeader().View())
+ // TODO(b/129292233): Determine if len(h) check is still needed after early
+ // parsing.
+ if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
@@ -789,12 +805,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
},
}
- packet.data = pkt.Data
+ // ICMP socket's data includes ICMP header.
+ packet.data = pkt.TransportHeader().View().ToVectorisedView()
+ packet.data.Append(pkt.Data)
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
@@ -805,7 +823,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
}
// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
@@ -830,3 +848,8 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements stack.TransportEndpoint.Wait.
func (*endpoint) Wait() {}
+
+// LastError implements tcpip.Endpoint.LastError.
+func (*endpoint) LastError() *tcpip.Error {
+ return nil
+}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 3c47692b2..87d510f96 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -13,12 +13,7 @@
// limitations under the License.
// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
-// protocols for use in ping. To use it in the networking stack, this package
-// must be added to the project, and activated on the stack by passing
-// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
-// protocols when calling stack.New(). Then endpoints can be created by passing
-// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
-// when calling Stack.NewEndpoint().
+// protocols for use in ping.
package icmp
import (
@@ -42,6 +37,8 @@ const (
// protocol implements stack.TransportProtocol.
type protocol struct {
+ stack *stack.Stack
+
number tcpip.TransportProtocolNumber
}
@@ -62,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
// NewEndpoint creates a new icmp endpoint. It implements
// stack.TransportProtocol.NewEndpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue)
+ return newEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// NewRawEndpoint creates a new raw icmp endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
+ return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// MinimumPacketSize returns the minimum valid icmp packet size.
@@ -104,17 +101,17 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
- return true
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ return stack.UnknownDestinationPacketHandled
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
+func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
+func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -124,12 +121,22 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ // TODO(gvisor.dev/issue/170): Implement parsing of ICMP.
+ //
+ // Right now, the Parse() method is tied to enabled protocols passed into
+ // stack.New. This works for UDP and TCP, but we handle ICMP traffic even
+ // when netstack users don't pass ICMP as a supported protocol.
+ return false
+}
+
// NewProtocol4 returns an ICMPv4 transport protocol.
-func NewProtocol4() stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
+func NewProtocol4(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber4}
}
// NewProtocol6 returns an ICMPv6 transport protocol.
-func NewProtocol6() stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
+func NewProtocol6(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 23158173d..31831a6d8 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,6 +25,8 @@
package packet
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -43,6 +45,9 @@ type packet struct {
timestampNS int64
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ // packetInfo holds additional information like the protocol
+ // of the packet etc.
+ packetInfo tcpip.LinkPacketInfo
}
// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
@@ -71,11 +76,19 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
+ boundNIC tcpip.NICID
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
+
+ // lastErrorMu protects lastError.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
}
// NewEndpoint returns a new packet endpoint.
@@ -92,6 +105,17 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
sndBufSize: 32 * 1024,
}
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ ep.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ ep.rcvBufSizeMax = rs.Default
+ }
+
if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
return nil, err
}
@@ -132,13 +156,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (stack.IPTables, error) {
- return ep.stack.IPTables(), nil
-}
-
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.PacketEndpoint.ReadPacket.
+func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -163,16 +182,25 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
*addr = packet.senderAddr
}
+ if info != nil {
+ *info = packet.packetInfo
+ }
+
return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
}
-func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- // TODO(b/129292371): Implement.
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return ep.ReadPacket(addr, nil)
+}
+
+func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/173): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
}
// Peek implements tcpip.Endpoint.Peek.
-func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -184,25 +212,25 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
// connected, and this function always returnes tcpip.ErrNotSupported.
-func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
// with Shutdown, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
// Listen, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+func (*endpoint) Listen(backlog int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
// Accept, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -220,12 +248,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound {
- return tcpip.ErrAlreadyBound
+ if ep.bound && ep.boundNIC == addr.NIC {
+ // If the NIC being bound is the same then just return success.
+ return nil
}
// Unregister endpoint with all the nics.
ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.bound = false
// Bind endpoint to receive packets from specific interface.
if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
@@ -233,17 +263,18 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
ep.bound = true
+ ep.boundNIC = addr.NIC
return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotSupported
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// Even a connected socket doesn't return a remote address.
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -268,8 +299,20 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
-func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+
+ case *tcpip.LingerOption:
+ ep.mu.Lock()
+ ep.linger = *v
+ ep.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
@@ -279,26 +322,113 @@ func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := ep.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ ep.mu.Lock()
+ ep.sndBufSizeMax = v
+ ep.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := ep.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ ep.rcvMu.Lock()
+ ep.rcvBufSizeMax = v
+ ep.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (ep *endpoint) LastError() *tcpip.Error {
+ ep.lastErrorMu.Lock()
+ defer ep.lastErrorMu.Unlock()
+
+ err := ep.lastError
+ ep.lastError = nil
+ return err
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ ep.mu.Lock()
+ *o = ep.linger
+ ep.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrNotSupported
+ }
}
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrNotSupported
+func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.AcceptConnOption:
+ return false, nil
+ default:
+ return false, tcpip.ErrNotSupported
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() {
+ p := ep.rcvList.Front()
+ v = p.data.Size()
+ }
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSizeMax
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -320,48 +450,73 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
// Push new packet into receive list and increment the buffer size.
var packet packet
- // TODO(b/129292371): Return network protocol.
- if len(pkt.LinkHeader) > 0 {
+ // TODO(gvisor.dev/issue/173): Return network protocol.
+ if !pkt.LinkHeader().View().IsEmpty() {
// Get info directly from the ethernet header.
- hdr := header.Ethernet(pkt.LinkHeader)
+ hdr := header.Ethernet(pkt.LinkHeader().View())
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(hdr.SourceAddress()),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(localAddr),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = pkt.Data
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip Link Header.
+ var combinedVV buffer.VectorisedView
+ if v := pkt.NetworkHeader().View(); !v.IsEmpty() {
+ combinedVV.AppendView(v)
+ }
+ if v := pkt.TransportHeader().View(); !v.IsEmpty() {
+ combinedVV.AppendView(v)
+ }
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
var linkHeader buffer.View
- if len(pkt.LinkHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if pkt.LinkHeader().View().IsEmpty() {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
+ combinedVV := linkHeader.ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
} else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
}
- combinedVV := linkHeader.ToVectorisedView()
- combinedVV.Append(pkt.Data)
- packet.data = combinedVV
}
- packet.timestampNS = ep.stack.NowNanoseconds()
+ packet.timestampNS = ep.stack.Clock().NowNanoseconds()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
@@ -375,7 +530,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
}
// State implements socket.Socket.State.
-func (ep *endpoint) State() uint32 {
+func (*endpoint) State() uint32 {
return 0
}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 9b88f17e4..e2fa96d17 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() {
panic(*err)
}
}
+
+// saveLastError is invoked by stateify.
+func (ep *endpoint) saveLastError() string {
+ if ep.lastError == nil {
+ return ""
+ }
+
+ return ep.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (ep *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ ep.lastError = tcpip.StringToError(s)
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index eee754a5a..79f688129 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,8 @@
package raw
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -61,25 +63,29 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
+ hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
rcvList rawPacketList
- rcvBufSizeMax int `state:".(int)"`
rcvBufSize int
+ rcvBufSizeMax int `state:".(int)"`
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- connected bool
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ connected bool
+ bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -91,7 +97,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber {
+ if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
@@ -103,8 +109,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
},
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
associated: associated,
+ hdrIncluded: !associated,
+ }
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
}
// Unassociated endpoints are write-only and users call Write() with IP
@@ -166,17 +184,8 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if !e.associated {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
- }
-
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -206,6 +215,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // We can create, but not write to, unassociated IPv6 endpoints.
+ if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -249,7 +263,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if !e.associated {
+ if e.hdrIncluded {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -310,12 +324,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrNoRoute
}
- // We don't support IPv6 yet, so this has to be an IPv4 address.
- if len(opts.To.Addr) != header.IPv4AddressSize {
- e.mu.RUnlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
// Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
@@ -345,28 +353,26 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- switch e.NetProto {
- case header.IPv4ProtocolNumber:
- if !e.associated {
- if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- }); err != nil {
- return 0, nil, err
- }
- break
+ if e.hdrIncluded {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
+ if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
+ return 0, nil, err
}
-
- hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- Owner: e.owner,
- }); err != nil {
+ } else {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(route.MaxHeaderLength()),
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
+ pkt.Owner = e.owner
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: e.TransProto,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, pkt); err != nil {
return 0, nil, err
}
-
- default:
- return 0, nil, tcpip.ErrUnknownProtocol
}
return int64(len(payloadBytes)), nil, nil
@@ -391,11 +397,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // We don't support IPv6 yet.
- if len(addr.Addr) != header.IPv4AddressSize {
- return tcpip.ErrInvalidEndpointState
- }
-
nic := addr.NIC
if e.bound {
if e.BindNICID == 0 {
@@ -447,12 +448,12 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Listen implements tcpip.Endpoint.Listen.
-func (e *endpoint) Listen(backlog int) *tcpip.Error {
+func (*endpoint) Listen(backlog int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept implements tcpip.Endpoint.Accept.
-func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -461,14 +462,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // Callers must provide an IPv4 address or no network address (for
- // binding to a NIC, but not an address).
- if len(addr.Addr) != 0 && len(addr.Addr) != 4 {
- return tcpip.ErrInvalidEndpointState
- }
-
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -489,12 +484,12 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotSupported
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// Even a connected socket doesn't return a remote address.
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -517,24 +512,85 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ e.hdrIncluded = v
+ e.mu.Unlock()
+ return nil
+ }
return tcpip.ErrUnknownProtocolOption
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ e.rcvMu.Lock()
+ e.rcvBufSizeMax = v
+ e.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ *o = e.linger
+ e.mu.Unlock()
return nil
default:
@@ -545,9 +601,15 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ v := e.hdrIncluded
+ e.mu.Unlock()
+ return v, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -568,7 +630,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -584,11 +646,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
- // Drop the packet if our buffer is currently full.
- if e.rcvClosed {
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
@@ -632,15 +701,26 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) {
},
}
- networkHeader := append(buffer.View(nil), pkt.NetworkHeader...)
- combinedVV := networkHeader.ToVectorisedView()
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ var combinedVV buffer.VectorisedView
+ if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(network)+len(transport))
+ headers = append(headers, network...)
+ headers = append(headers, transport...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
+ }
combinedVV.Append(pkt.Data)
packet.data = combinedVV
- packet.timestampNS = e.stack.NowNanoseconds()
+ packet.timestampNS = e.stack.Clock().NowNanoseconds()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
-
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
@@ -670,3 +750,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements stack.TransportEndpoint.Wait.
func (*endpoint) Wait() {}
+
+func (*endpoint) LastError() *tcpip.Error {
+ return nil
+}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 33bfb56cd..7d97cbdc7 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -37,57 +37,57 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) {
}
// beforeSave is invoked by stateify.
-func (ep *endpoint) beforeSave() {
+func (e *endpoint) beforeSave() {
// Stop incoming packets from being handled (and mutate endpoint state).
// The lock will be released after saveRcvBufSizeMax(), which would have
- // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
// packets.
- ep.rcvMu.Lock()
+ e.rcvMu.Lock()
}
// saveRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) saveRcvBufSizeMax() int {
- max := ep.rcvBufSizeMax
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
// Make sure no new packets will be handled regardless of the lock.
- ep.rcvBufSizeMax = 0
+ e.rcvBufSizeMax = 0
// Release the lock acquired in beforeSave() so regular endpoint closing
// logic can proceed after save.
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
return max
}
// loadRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) loadRcvBufSizeMax(max int) {
- ep.rcvBufSizeMax = max
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
}
// afterLoad is invoked by stateify.
-func (ep *endpoint) afterLoad() {
- stack.StackFromEnv.RegisterRestoredEndpoint(ep)
+func (e *endpoint) afterLoad() {
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
// Resume implements tcpip.ResumableEndpoint.Resume.
-func (ep *endpoint) Resume(s *stack.Stack) {
- ep.stack = s
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
// If the endpoint is connected, re-connect.
- if ep.connected {
+ if e.connected {
var err *tcpip.Error
- ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false)
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false)
if err != nil {
panic(err)
}
}
// If the endpoint is bound, re-bind.
- if ep.bound {
- if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 {
+ if e.bound {
+ if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
- if ep.associated {
- if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
+ if e.associated {
+ if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
panic(err)
}
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index f38eb6833..518449602 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -40,6 +40,8 @@ go_library(
"endpoint_state.go",
"forwarder.go",
"protocol.go",
+ "rack.go",
+ "rack_state.go",
"rcv.go",
"rcv_state.go",
"reno.go",
@@ -49,6 +51,7 @@ go_library(
"segment_heap.go",
"segment_queue.go",
"segment_state.go",
+ "segment_unsafe.go",
"snd.go",
"snd_state.go",
"tcp_endpoint_list.go",
@@ -66,6 +69,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
@@ -76,22 +80,21 @@ go_library(
)
go_test(
- name = "tcp_test",
+ name = "tcp_x_test",
size = "medium",
srcs = [
"dual_stack_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
+ "tcp_rack_test.go",
"tcp_sack_test.go",
"tcp_test.go",
"tcp_timestamp_test.go",
],
- # FIXME(b/68809571)
- tags = [
- "flaky",
- ],
+ shard_count = 10,
deps = [
":tcp",
+ "//pkg/rand",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
@@ -119,3 +122,11 @@ go_test(
"//pkg/tcpip/seqnum",
],
)
+
+go_test(
+ name = "tcp_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ library = ":tcp",
+ deps = ["//pkg/sleep"],
+)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index e6a23c978..6b3238d6b 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -198,9 +198,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
}
// createConnectingEndpoint creates a new endpoint in a connecting state, with
-// the connection parameters given by the arguments. The endpoint is returned
-// with n.mu held.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
+// the connection parameters given by the arguments.
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -213,7 +212,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
- n.amss = mssForRoute(&n.route)
+ n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
n.setEndpointState(StateConnecting)
n.maybeEnableTimestamp(rcvdSynOpts)
@@ -221,32 +220,12 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
- // Create sender and receiver.
- //
- // The receiver at least temporarily has a zero receive window scale,
- // but the caller may change it (before starting the protocol loop).
- n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
- n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize()))
// Bootstrap the auto tuning algorithm. Starting at zero will result in
// a large step function on the first window adjustment causing the
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
- // Lock the endpoint before registering to ensure that no out of
- // band changes are possible due to incoming packets etc till
- // the endpoint is done initializing.
- n.mu.Lock()
-
- // Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
- n.mu.Unlock()
- n.Close()
- return nil, err
- }
-
- n.isRegistered = true
-
- return n, nil
+ return n
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
@@ -257,10 +236,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
- if err != nil {
- return nil, err
- }
+ ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+
+ // Lock the endpoint before registering to ensure that no out of
+ // band changes are possible due to incoming packets etc till
+ // the endpoint is done initializing.
+ ep.mu.Lock()
ep.owner = owner
// listenEP is nil when listenContext is used by tcp.Forwarder.
@@ -268,18 +249,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if l.listenEP != nil {
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
+
l.listenEP.mu.Unlock()
// Ensure we release any registrations done by the newly
// created endpoint.
ep.mu.Unlock()
ep.Close()
- // Wake up any waiters. This is strictly not required normally
- // as a socket that was never accepted can't really have any
- // registered waiters except when stack.Wait() is called which
- // waits for all registered endpoints to stop and expects an
- // EventHUp.
- ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
@@ -288,21 +264,44 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// to the newly created endpoint.
l.listenEP.propagateInheritableOptionsLocked(ep)
+ if !ep.reserveTupleLocked() {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ l.listenEP.mu.Unlock()
+ }
+
+ return nil, tcpip.ErrConnectionAborted
+ }
+
deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
+ // Register new endpoint so that packets are routed to it.
+ if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
+
+ ep.drainClosingSegmentQueue()
+
+ return nil, err
+ }
+
+ ep.isRegistered = true
+
// Perform the 3-way handshake.
- h := newPassiveHandshake(ep, ep.rcv.rcvWnd, isn, irs, opts, deferAccept)
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.mu.Unlock()
ep.Close()
- // Wake up any waiters. This is strictly not required normally
- // as a socket that was never accepted can't really have any
- // registered waiters except when stack.Wait() is called which
- // waits for all registered endpoints to stop and expects an
- // EventHUp.
- ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ ep.notifyAborted()
if l.listenEP != nil {
l.removePendingEndpoint(ep)
@@ -378,6 +377,44 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
// Precondition: e.mu and n.mu must be held.
func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
+ n.portFlags = e.portFlags
+ n.boundBindToDevice = e.boundBindToDevice
+ n.boundPortFlags = e.boundPortFlags
+ n.userMSS = e.userMSS
+}
+
+// reserveTupleLocked reserves an accepted endpoint's tuple.
+//
+// Preconditions:
+// * propagateInheritableOptionsLocked has been called.
+// * e.mu is held.
+func (e *endpoint) reserveTupleLocked() bool {
+ dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort}
+ if !e.stack.ReserveTuple(
+ e.effectiveNetProtos,
+ ProtocolNumber,
+ e.ID.LocalAddress,
+ e.ID.LocalPort,
+ e.boundPortFlags,
+ e.boundBindToDevice,
+ dest,
+ ) {
+ return false
+ }
+
+ e.isPortReserved = true
+ e.boundDest = dest
+ return true
+}
+
+// notifyAborted wakes up any waiters on registered, but not accepted
+// endpoints.
+//
+// This is strictly not required normally as a socket that was never accepted
+// can't really have any registered waiters except when stack.Wait() is called
+// which waits for all registered endpoints to stop and expects an EventHUp.
+func (e *endpoint) notifyAborted() {
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
// handleSynSegment is called in its own goroutine once the listening endpoint
@@ -388,20 +425,17 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
defer ctx.synRcvdCount.dec()
- defer func() {
- e.mu.Lock()
- e.decSynRcvdCount()
- e.mu.Unlock()
- }()
defer s.decRef()
n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
+ e.decSynRcvdCount()
return
}
ctx.removePendingEndpoint(n)
+ e.decSynRcvdCount()
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
@@ -419,7 +453,9 @@ func (e *endpoint) incSynRcvdCount() bool {
}
func (e *endpoint) decSynRcvdCount() {
+ e.mu.Lock()
e.synRcvdCount--
+ e.mu.Unlock()
}
func (e *endpoint) acceptQueueIsFull() bool {
@@ -445,9 +481,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
- // TODO(b/143300739): Use the userMSS of the listening socket
- // for accepted sockets.
-
switch {
case s.flags == header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
@@ -478,16 +511,19 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
// Send SYN without window scaling because we currently
- // dont't encode this information in the cookie.
+ // don't encode this information in the cookie.
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
+ //
+ // Use the user supplied MSS on the listening socket for
+ // new connections, if available.
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(timeStampOffset()),
+ TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: mssForRoute(&s.route),
+ MSS: calculateAdvertisedMSS(e.userMSS, s.route),
}
e.sendSynTCP(&s.route, tcpFields{
id: s.id,
@@ -534,6 +570,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
+ iss := s.ackNumber - 1
+ irs := s.sequenceNumber - 1
+
// Since SYN cookies are in use this is potentially an ACK to a
// SYN-ACK we sent but don't have a half open connection state
// as cookies are being used to protect against a potential SYN
@@ -544,7 +583,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// when under a potential syn flood attack.
//
// Validate the cookie.
- data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
+ data, ok := ctx.isCookieValid(s.id, iss, irs)
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -569,16 +608,34 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{})
- if err != nil {
+ n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+
+ n.mu.Lock()
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ e.propagateInheritableOptionsLocked(n)
+
+ if !n.reserveTupleLocked() {
+ n.mu.Unlock()
+ n.Close()
+
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+ n.mu.Unlock()
+ n.Close()
+
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
- // Propagate any inheritable options from the listening endpoint
- // to the newly created endpoint.
- e.propagateInheritableOptionsLocked(n)
+ n.isRegistered = true
// clear the tsOffset for the newly created
// endpoint as the Timestamp was already
@@ -587,10 +644,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.tsOffset = 0
// Switch state to connected.
- // We do not use transitionToStateEstablishedLocked here as there is
- // no handshake state available when doing a SYN cookie based accept.
n.isConnectNotified = true
- n.setEndpointState(StateEstablished)
+ n.transitionToStateEstablishedLocked(&handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ })
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index e4a06c9e1..0aaef495d 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.LastError()
+ }
}
// Wait for notification.
@@ -509,9 +512,7 @@ func (h *handshake) execute() *tcpip.Error {
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
- rt := time.AfterFunc(timeOut, func() {
- resendWaker.Assert()
- })
+ rt := time.AfterFunc(timeOut, resendWaker.Assert)
defer rt.Stop()
// Set up the wakers.
@@ -521,7 +522,7 @@ func (h *handshake) execute() *tcpip.Error {
s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
defer s.Done()
- var sackEnabled SACKEnabled
+ var sackEnabled tcpip.TCPSACKEnabled
if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
// If stack returned an error when checking for SACKEnabled
// status then just default to switching off SACK negotiation.
@@ -618,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.LastError()
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
@@ -742,11 +746,8 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV
func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
optLen := len(tf.opts)
- hdr := &pkt.Header
- packetSize := pkt.Data.Size()
- // Initialize the header.
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
- pkt.TransportHeader = buffer.View(tcp)
+ tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen))
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
tcp.Encode(&header.TCPFields{
SrcPort: tf.id.LocalPort,
DstPort: tf.id.RemotePort,
@@ -758,8 +759,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
})
copy(tcp[header.TCPMinimumSize:], tf.opts)
- length := uint16(hdr.UsedLength() + packetSize)
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size()))
// Only calculate the checksum if offloading isn't supported.
if gso != nil && gso.NeedsCsum {
// This is called CHECKSUM_PARTIAL in the Linux kernel. We
@@ -797,17 +797,18 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
packetSize = size
}
size -= packetSize
- var pkt stack.PacketBuffer
- pkt.Header = buffer.NewPrependable(hdrSize)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: hdrSize,
+ })
pkt.Hash = tf.txHash
pkt.Owner = owner
pkt.EgressRoute = r
pkt.GSOOptions = gso
- pkt.NetworkProtocolNumber = r.NetworkProtocolNumber()
+ pkt.NetworkProtocolNumber = r.NetProto
data.ReadToVV(&pkt.Data, packetSize)
- buildTCPHdr(r, tf, &pkt, gso)
+ buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
- pkts.PushBack(&pkt)
+ pkts.PushBack(pkt)
}
if tf.ttl == 0 {
@@ -833,13 +834,13 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac
return sendTCPBatch(r, tf, data, gso, owner)
}
- pkt := stack.PacketBuffer{
- Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
- Data: data,
- Hash: tf.txHash,
- Owner: owner,
- }
- buildTCPHdr(r, tf, &pkt, gso)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen,
+ Data: data,
+ })
+ pkt.Hash = tf.txHash
+ pkt.Owner = owner
+ buildTCPHdr(r, tf, pkt, gso)
if tf.ttl == 0 {
tf.ttl = r.DefaultTTL()
@@ -897,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -995,24 +996,21 @@ func (e *endpoint) completeWorkerLocked() {
// transitionToStateEstablisedLocked transitions a given endpoint
// to an established state using the handshake parameters provided.
-// It also initializes sender/receiver if required.
+// It also initializes sender/receiver.
func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
- if e.snd == nil {
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
- }
- if e.rcv == nil {
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
- e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
- // Bootstrap the auto tuning algorithm. Starting at zero will
- // result in a really large receive window after the first auto
- // tuning adjustment.
- e.rcvAutoParams.prevCopied = int(h.rcvWnd)
- e.rcvListMu.Unlock()
- }
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd)
+ e.rcvListMu.Unlock()
+
e.setEndpointState(StateEstablished)
}
@@ -1022,14 +1020,19 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// delivered to this endpoint from the demuxer when the endpoint
// is transitioned to StateClose.
func (e *endpoint) transitionToStateCloseLocked() {
- if e.EndpointState() == StateClose {
+ s := e.EndpointState()
+ if s == StateClose {
return
}
+
+ if s.connected() {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
+ }
+
// Mark the endpoint as fully closed for reads/writes.
e.cleanupLocked()
e.setEndpointState(StateClose)
- e.stack.Stats().TCP.CurrentConnected.Decrement()
- e.stack.Stats().TCP.EstablishedClosed.Increment()
}
// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
@@ -1052,8 +1055,8 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
panic("current endpoint not removed from demuxer, enqueing segments to itself")
}
- if ep.(*endpoint).enqueueSegment(s) {
- ep.(*endpoint).newSegmentWaker.Assert()
+ if ep := ep.(*endpoint); ep.enqueueSegment(s) {
+ ep.newSegmentWaker.Assert()
}
}
@@ -1122,7 +1125,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
- if e.EndpointState() == StateClose || e.EndpointState() == StateError {
+ if e.EndpointState().closed() {
return nil
}
s := e.segmentQueue.dequeue()
@@ -1132,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
}
cont, err := e.handleSegment(s)
+ s.decRef()
if err != nil {
- s.decRef()
return err
}
if !cont {
- s.decRef()
return nil
}
}
@@ -1159,13 +1161,18 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
return nil
}
-// handleSegment handles a given segment and notifies the worker goroutine if
-// if the connection should be terminated.
-func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
- // Invoke the tcp probe if installed.
+func (e *endpoint) probeSegment() {
if e.probe != nil {
e.probe(e.completeState())
}
+}
+
+// handleSegment handles a given segment and notifies the worker goroutine if
+// if the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed. The tcp probe function will update
+ // the TCPEndpointState after the segment is processed.
+ defer e.probeSegment()
if s.flagIsSet(header.TCPFlagRst) {
if ok, err := e.handleReset(s); !ok {
@@ -1224,7 +1231,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// or a notification from the protocolMainLoop (caller goroutine).
// This means that with this return, the segment dequeue below can
// never occur on a closed endpoint.
- s.decRef()
return false, nil
}
@@ -1416,10 +1422,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.rcv.nonZeroWindow()
}
- if n&notifyReceiveWindowChanged != 0 {
- e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
- }
-
if n&notifyMTUChanged != 0 {
e.sndBufMu.Lock()
count := e.packetTooBigCount
@@ -1442,9 +1444,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
- closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
- closeWaker.Assert()
- })
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
}
@@ -1462,7 +1462,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
return err
}
}
- if e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ if !e.EndpointState().closed() {
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
@@ -1518,6 +1518,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
cleanupOnError := func(err *tcpip.Error) {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
e.workerCleanup = true
if err != nil {
e.resetConnectionLocked(err)
@@ -1527,7 +1528,12 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
loop:
- for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ for {
+ switch e.EndpointState() {
+ case StateTimeWait, StateClose, StateError:
+ break loop
+ }
+
e.mu.Unlock()
v, _ := s.Fetch(true)
e.mu.Lock()
@@ -1570,11 +1576,14 @@ loop:
reuseTW = e.doTimeWait()
}
- // Mark endpoint as closed.
- if e.EndpointState() != StateError {
- e.transitionToStateCloseLocked()
+ // Handle any StateError transition from StateTimeWait.
+ if e.EndpointState() == StateError {
+ cleanupOnError(nil)
+ return nil
}
+ e.transitionToStateCloseLocked()
+
// Lock released below.
epilogue()
@@ -1687,7 +1696,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
}
case notification:
n := e.fetchNotifications()
- if n&notifyClose != 0 || n&notifyAbort != 0 {
+ if n&notifyAbort != 0 {
return nil
}
if n&notifyDrain != 0 {
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 6062ca916..98aecab9e 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -15,6 +15,8 @@
package tcp
import (
+ "encoding/binary"
+
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
@@ -66,89 +68,68 @@ func (q *epQueue) empty() bool {
// processor is responsible for processing packets queued to a tcp endpoint.
type processor struct {
epQ epQueue
+ sleeper sleep.Sleeper
newEndpointWaker sleep.Waker
closeWaker sleep.Waker
- id int
- wg sync.WaitGroup
-}
-
-func newProcessor(id int) *processor {
- p := &processor{
- id: id,
- }
- p.wg.Add(1)
- go p.handleSegments()
- return p
}
func (p *processor) close() {
p.closeWaker.Assert()
}
-func (p *processor) wait() {
- p.wg.Wait()
-}
-
func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep)
p.newEndpointWaker.Assert()
}
-func (p *processor) handleSegments() {
- const newEndpointWaker = 1
- const closeWaker = 2
- s := sleep.Sleeper{}
- s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
- s.AddWaker(&p.closeWaker, closeWaker)
- defer s.Done()
+const (
+ newEndpointWaker = 1
+ closeWaker = 2
+)
+
+func (p *processor) start(wg *sync.WaitGroup) {
+ defer wg.Done()
+ defer p.sleeper.Done()
+
for {
- id, ok := s.Fetch(true)
- if ok && id == closeWaker {
- p.wg.Done()
- return
+ if id, _ := p.sleeper.Fetch(true); id == closeWaker {
+ break
}
- for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
+ for {
+ ep := p.epQ.dequeue()
+ if ep == nil {
+ break
+ }
if ep.segmentQueue.empty() {
continue
}
- // If socket has transitioned out of connected state
- // then just let the worker handle the packet.
+ // If socket has transitioned out of connected state then just let the
+ // worker handle the packet.
//
- // NOTE: We read this outside of e.mu lock which means
- // that by the time we get to handleSegments the
- // endpoint may not be in ESTABLISHED. But this should
- // be fine as all normal shutdown states are handled by
- // handleSegments and if the endpoint moves to a
- // CLOSED/ERROR state then handleSegments is a noop.
- if ep.EndpointState() != StateEstablished {
- ep.newSegmentWaker.Assert()
- continue
- }
-
- if !ep.mu.TryLock() {
- ep.newSegmentWaker.Assert()
- continue
- }
- // If the endpoint is in a connected state then we do
- // direct delivery to ensure low latency and avoid
- // scheduler interactions.
- if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
- // Send any active resets if required.
- if err != nil {
+ // NOTE: We read this outside of e.mu lock which means that by the time
+ // we get to handleSegments the endpoint may not be in ESTABLISHED. But
+ // this should be fine as all normal shutdown states are handled by
+ // handleSegments and if the endpoint moves to a CLOSED/ERROR state
+ // then handleSegments is a noop.
+ if ep.EndpointState() == StateEstablished && ep.mu.TryLock() {
+ // If the endpoint is in a connected state then we do direct delivery
+ // to ensure low latency and avoid scheduler interactions.
+ switch err := ep.handleSegments(true /* fastPath */); {
+ case err != nil:
+ // Send any active resets if required.
ep.resetConnectionLocked(err)
+ fallthrough
+ case ep.EndpointState() == StateClose:
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ case !ep.segmentQueue.empty():
+ p.epQ.enqueue(ep)
}
- ep.notifyProtocolGoroutine(notifyTickleWorker)
ep.mu.Unlock()
- continue
- }
-
- if !ep.segmentQueue.empty() {
- p.epQ.enqueue(ep)
+ } else {
+ ep.newSegmentWaker.Assert()
}
-
- ep.mu.Unlock()
}
}
}
@@ -159,34 +140,39 @@ func (p *processor) handleSegments() {
// hash of the endpoint id to ensure that delivery for the same endpoint happens
// in-order.
type dispatcher struct {
- processors []*processor
+ processors []processor
seed uint32
-}
-
-func newDispatcher(nProcessors int) *dispatcher {
- processors := []*processor{}
- for i := 0; i < nProcessors; i++ {
- processors = append(processors, newProcessor(i))
- }
- return &dispatcher{
- processors: processors,
- seed: generateRandUint32(),
+ wg sync.WaitGroup
+}
+
+func (d *dispatcher) init(nProcessors int) {
+ d.close()
+ d.wait()
+ d.processors = make([]processor, nProcessors)
+ d.seed = generateRandUint32()
+ for i := range d.processors {
+ p := &d.processors[i]
+ p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ p.sleeper.AddWaker(&p.closeWaker, closeWaker)
+ d.wg.Add(1)
+ // NB: sleeper-waker registration must happen synchronously to avoid races
+ // with `close`. It's possible to pull all this logic into `start`, but
+ // that results in a heap-allocated function literal.
+ go p.start(&d.wg)
}
}
func (d *dispatcher) close() {
- for _, p := range d.processors {
- p.close()
+ for i := range d.processors {
+ d.processors[i].close()
}
}
func (d *dispatcher) wait() {
- for _, p := range d.processors {
- p.wait()
- }
+ d.wg.Wait()
}
-func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
if !s.parse() {
@@ -231,20 +217,18 @@ func generateRandUint32() uint32 {
if _, err := rand.Read(b); err != nil {
panic(err)
}
- return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+ return binary.LittleEndian.Uint32(b)
}
func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
- payload := []byte{
- byte(id.LocalPort),
- byte(id.LocalPort >> 8),
- byte(id.RemotePort),
- byte(id.RemotePort >> 8)}
+ var payload [4]byte
+ binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.LittleEndian.PutUint16(payload[2:], id.RemotePort)
h := jenkins.Sum32(d.seed)
- h.Write(payload)
+ h.Write(payload[:])
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
- return d.processors[h.Sum32()%uint32(len(d.processors))]
+ return &d.processors[h.Sum32()%uint32(len(d.processors))]
}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 804e95aea..560b4904c 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -78,16 +78,15 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network
ackCheckers := append(checkers, checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
))
checker.IPv4(t, c.GetPacket(), ackCheckers...)
// Wait for connection to be established.
select {
case <-ch:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.LastError(); err != nil {
t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -186,16 +185,15 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network
ackCheckers := append(checkers, checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
))
checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
// Wait for connection to be established.
select {
case <-ch:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.LastError(); err != nil {
t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -285,7 +283,7 @@ func TestV4RefuseOnV6Only(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
}
@@ -321,7 +319,7 @@ func TestV6RefuseOnBoundToV4Mapped(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
}
@@ -354,7 +352,7 @@ func testV4Accept(t *testing.T, c *context.Context) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
@@ -373,12 +371,12 @@ func testV4Accept(t *testing.T, c *context.Context) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept()
+ nep, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -494,7 +492,7 @@ func TestV6AcceptOnV6(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
@@ -512,13 +510,13 @@ func TestV6AcceptOnV6(t *testing.T) {
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
-
- nep, _, err := c.EP.Accept()
+ var addr tcpip.FullAddress
+ nep, _, err := c.EP.Accept(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(&addr)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -528,20 +526,14 @@ func TestV6AcceptOnV6(t *testing.T) {
}
}
+ if addr.Addr != context.TestV6Addr {
+ t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
+ }
+
// Make sure we can still query the v6 only status of the new endpoint,
// that is, that it is in fact a v6 socket.
if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
- t.Fatalf("GetSockOpt failed failed: %v", err)
- }
-
- // Check the peer address.
- addr, err := nep.GetRemoteAddress()
- if err != nil {
- t.Fatalf("GetRemoteAddress failed failed: %v", err)
- }
-
- if addr.Addr != context.TestV6Addr {
- t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr)
+ t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err)
}
}
@@ -568,8 +560,9 @@ func TestV4AcceptOnV4(t *testing.T) {
func testV4ListenClose(t *testing.T, c *context.Context) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
const n = uint16(32)
@@ -612,12 +605,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept()
+ nep, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b5ba972f1..c826942e9 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -63,7 +63,19 @@ const (
StateClosing
)
-// connected is the set of states where an endpoint is connected to a peer.
+const (
+ // rcvAdvWndScale is used to split the available socket buffer into
+ // application buffer and the window to be advertised to the peer. This is
+ // currently hard coded to split the available space equally.
+ rcvAdvWndScale = 1
+
+ // SegOverheadFactor is used to multiply the value provided by the
+ // user on a SetSockOpt for setting the socket send/receive buffer sizes.
+ SegOverheadFactor = 2
+)
+
+// connected returns true when s is one of the states representing an
+// endpoint connected to a peer.
func (s EndpointState) connected() bool {
switch s {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
@@ -73,6 +85,40 @@ func (s EndpointState) connected() bool {
}
}
+// connecting returns true when s is one of the states representing a
+// connection in progress, but not yet fully established.
+func (s EndpointState) connecting() bool {
+ switch s {
+ case StateConnecting, StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// handshake returns true when s is one of the states representing an endpoint
+// in the middle of a TCP handshake.
+func (s EndpointState) handshake() bool {
+ switch s {
+ case StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// closed returns true when s is one of the states an endpoint transitions to
+// when closed or when it encounters an error. This is distinct from a newly
+// initialized endpoint that was never connected.
+func (s EndpointState) closed() bool {
+ switch s {
+ case StateClose, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
// String implements fmt.Stringer.String.
func (s EndpointState) String() string {
switch s {
@@ -114,7 +160,6 @@ func (s EndpointState) String() string {
// Reasons for notifying the protocol goroutine.
const (
notifyNonZeroReceiveWindow = 1 << iota
- notifyReceiveWindowChanged
notifyClose
notifyMTUChanged
notifyDrain
@@ -203,6 +248,11 @@ type ReceiveErrors struct {
// ZeroRcvWindowState is the number of times we advertised
// a zero receive window when rcvList is full.
ZeroRcvWindowState tcpip.StatCounter
+
+ // WantZeroWindow is the number of times we wanted to advertise a
+ // zero receive window but couldn't because it would have caused
+ // the receive window's right edge to shrink.
+ WantZeroRcvWindow tcpip.StatCounter
}
// SendErrors collect segment send errors within the transport layer.
@@ -349,19 +399,33 @@ type endpoint struct {
// to indicate to users that no more data is coming.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
- rcvList segmentList `state:"wait"`
- rcvClosed bool
- rcvBufSize int
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ // rcvBufSize is the total size of the receive buffer.
+ rcvBufSize int
+ // rcvBufUsed is the actual number of payload bytes held in the receive buffer
+ // not counting any overheads of the segments itself. NOTE: This will always
+ // be strictly <= rcvMemUsed below.
rcvBufUsed int
rcvAutoParams rcvBufAutoTuneParams
+ // rcvMemUsed tracks the total amount of memory in use by received segments
+ // held in rcvList, pendingRcvdSegments and the segment queue. This is used to
+ // compute the window and the actual available buffer space. This is distinct
+ // from rcvBufUsed above which is the actual number of payload bytes held in
+ // the buffer not including any segment overheads.
+ //
+ // rcvMemUsed must be accessed atomically.
+ rcvMemUsed int32
+
// mu protects all endpoint fields unless documented otherwise. mu must
// be acquired before interacting with the endpoint fields.
mu sync.Mutex `state:"nosave"`
ownedByUser uint32
- // state must be read/set using the EndpointState()/setEndpointState() methods.
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
state EndpointState `state:".(EndpointState)"`
// origEndpointState is only used during a restore phase to save the
@@ -370,8 +434,8 @@ type endpoint struct {
origEndpointState EndpointState `state:"nosave"`
isPortReserved bool `state:"manual"`
- isRegistered bool
- boundNICID tcpip.NICID `state:"manual"`
+ isRegistered bool `state:"manual"`
+ boundNICID tcpip.NICID
route stack.Route `state:"manual"`
ttl uint8
v6only bool
@@ -380,10 +444,14 @@ type endpoint struct {
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool
+ // portFlags stores the current values of port related flags.
+ portFlags ports.Flags
+
// Values used to reserve a port or register a transport endpoint
// (which ever happens first).
boundBindToDevice tcpip.NICID
boundPortFlags ports.Flags
+ boundDest tcpip.FullAddress
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -391,7 +459,7 @@ type endpoint struct {
// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
- effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
@@ -409,10 +477,11 @@ type endpoint struct {
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
- //
- // recentTS must be read/written atomically.
recentTS uint32
+ // recentTSTime is the unix time when we updated recentTS last.
+ recentTSTime time.Time `state:".(unixTime)"`
+
// tsOffset is a randomized offset added to the value of the
// TSVal field in the timestamp option.
tsOffset uint32
@@ -427,9 +496,6 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
- // reusePort is set to true if SO_REUSEPORT is enabled.
- reusePort bool
-
// bindToDevice is set to the NIC on which to bind or disabled if 0.
bindToDevice tcpip.NICID
@@ -449,7 +515,6 @@ type endpoint struct {
// The options below aren't implemented, but we remember the user
// settings because applications expect to be able to set/query these
// options.
- reuseAddr bool
// slowAck holds the negated state of quick ack. It is stubbed out and
// does nothing.
@@ -617,6 +682,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -630,7 +698,8 @@ func (e *endpoint) UniqueID() uint64 {
// r, it will be used; otherwise, the maximum possible MSS will be used.
func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
// The maximum possible MSS is dependent on the route.
- maxMSS := mssForRoute(&r)
+ // TODO(b/143359391): Respect TCP Min and Max size.
+ maxMSS := uint16(r.MTU() - header.TCPMinimumSize)
if userMSS != 0 && userMSS < maxMSS {
return userMSS
@@ -759,15 +828,15 @@ func (e *endpoint) EndpointState() EndpointState {
return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
}
-// setRecentTimestamp atomically sets the recentTS field to the
-// provided value.
+// setRecentTimestamp sets the recentTS field to the provided value.
func (e *endpoint) setRecentTimestamp(recentTS uint32) {
- atomic.StoreUint32(&e.recentTS, recentTS)
+ e.recentTS = recentTS
+ e.recentTSTime = time.Now()
}
-// recentTimestamp atomically reads and returns the value of the recentTS field.
+// recentTimestamp returns the value of the recentTS field.
func (e *endpoint) recentTimestamp() uint32 {
- return atomic.LoadUint32(&e.recentTS)
+ return e.recentTS
}
// keepalive is a synchronization wrapper used to appease stateify. See the
@@ -799,7 +868,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
rcvBufSize: DefaultReceiveBufferSize,
sndBufSize: DefaultSendBufferSize,
sndMTU: int(math.MaxInt32),
- reuseAddr: true,
keepalive: keepalive{
// Linux defaults.
idle: 2 * time.Hour,
@@ -812,12 +880,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
maxSynRetries: DefaultSynRetries,
}
- var ss SendBufferSizeOption
+ var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
e.sndBufSize = ss.Default
}
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
e.rcvBufSize = rs.Default
}
@@ -827,12 +895,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.cc = cs
}
- var mrb tcpip.ModerateReceiveBufferOption
+ var mrb tcpip.TCPModerateReceiveBufferOption
if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
e.rcvAutoParams.disabled = !bool(mrb)
}
- var de DelayEnabled
+ var de tcpip.TCPDelayEnabled
if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
e.SetSockOptBool(tcpip.DelayOption, true)
}
@@ -851,7 +919,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.probe = p
}
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.segmentQueue.ep = e
e.tsOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
@@ -864,10 +932,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result := waiter.EventMask(0)
switch e.EndpointState() {
- case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
+ case StateInitial, StateBound:
+ // This prevents blocking of new sockets which are not
+ // connected when SO_LINGER is set.
+ result |= waiter.EventHUp
+
+ case StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
- case StateClose, StateError:
+ case StateClose, StateError, StateTimeWait:
// Ready for anything.
result = mask
@@ -970,6 +1043,26 @@ func (e *endpoint) Close() {
return
}
+ if e.linger.Enabled && e.linger.Timeout == 0 {
+ s := e.EndpointState()
+ isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv
+ if isResetState {
+ // Close the endpoint without doing full shutdown and
+ // send a RST.
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.closeNoShutdownLocked()
+
+ // Wake up worker to close the endpoint.
+ switch s {
+ case StateSynRecv:
+ e.notifyProtocolGoroutine(notifyClose)
+ default:
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
+ return
+ }
+ }
+
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
@@ -986,14 +1079,15 @@ func (e *endpoint) closeNoShutdownLocked() {
// in Listen() when trying to register.
if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
e.isPortReserved = false
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
}
// Mark endpoint as closed.
@@ -1014,6 +1108,8 @@ func (e *endpoint) closeNoShutdownLocked() {
e.notifyProtocolGoroutine(notifyClose)
} else {
e.transitionToStateCloseLocked()
+ // Notify that the endpoint is closed.
+ e.waiterQueue.Notify(waiter.EventHUp)
}
}
@@ -1051,26 +1147,33 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
e.isPortReserved = false
}
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
e.route.Release()
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
+// wndFromSpace returns the window that we can advertise based on the available
+// receive buffer space.
+func wndFromSpace(space int) int {
+ return space >> rcvAdvWndScale
+}
+
// initialReceiveWindow returns the initial receive window to advertise in the
// SYN/SYN-ACK.
func (e *endpoint) initialReceiveWindow() int {
- rcvWnd := e.receiveBufferAvailable()
+ rcvWnd := wndFromSpace(e.receiveBufferAvailable())
if rcvWnd > math.MaxUint16 {
rcvWnd = math.MaxUint16
}
@@ -1147,14 +1250,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// reject valid data that might already be in flight as the
// acceptable window will shrink.
if rcvWnd > e.rcvBufSize {
- availBefore := e.receiveBufferAvailableLocked()
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvBufSize = rcvWnd
- availAfter := e.receiveBufferAvailableLocked()
- mask := uint32(notifyReceiveWindowChanged)
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- mask |= notifyNonZeroReceiveWindow
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
- e.notifyProtocolGoroutine(mask)
}
// We only update prevCopied when we grow the buffer because in cases
@@ -1172,14 +1273,27 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
+func (e *endpoint) LastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+ err := e.lastError
+ e.lastError = nil
+ return err
}
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
+ defer e.UnlockUser()
+
+ // When in SYN-SENT state, let the caller block on the receive.
+ // An application can initiate a non-blocking connect and then block
+ // on a receive. It can expect to read any data after the handshake
+ // is complete. RFC793, section 3.9, p58.
+ if e.EndpointState() == StateSynSent {
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data. Also note that a RST being received
// would cause the state to become StateError so we should allow the
@@ -1189,7 +1303,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
- e.UnlockUser()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
@@ -1199,7 +1312,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
- e.UnlockUser()
if err == tcpip.ErrClosedForReceive {
e.stats.ReadErrors.ReadClosed.Increment()
@@ -1220,18 +1332,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
v := views[s.viewToDeliver]
s.viewToDeliver++
+ var delta int
if s.viewToDeliver >= len(views) {
e.rcvList.Remove(s)
+ // We only free up receive buffer space when the segment is released as the
+ // segment is still holding on to the views even though some views have been
+ // read out to the user.
+ delta = s.segMemSize()
s.decRef()
}
e.rcvBufUsed -= len(v)
-
// If the window was small before this read and if the read freed up
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1244,14 +1360,17 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
- // The endpoint cannot be written to if it's not connected.
- if !e.EndpointState().connected() {
- switch e.EndpointState() {
- case StateError:
- return 0, e.HardError
- default:
- return 0, tcpip.ErrClosedForSend
- }
+ switch s := e.EndpointState(); {
+ case s == StateError:
+ return 0, e.HardError
+ case !s.connecting() && !s.connected():
+ return 0, tcpip.ErrClosedForSend
+ case s.connecting():
+ // As per RFC793, page 56, a send request arriving when in connecting
+ // state, can be queued to be completed after the state becomes
+ // connected. Return an error code for the caller of endpoint Write to
+ // try again, until the connection handshake is complete.
+ return 0, tcpip.ErrWouldBlock
}
// Check if the connection has already been closed for sends.
@@ -1404,12 +1523,44 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
+// selectWindowLocked returns the new window without checking for shrinking or scaling
+// applied.
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) {
+ wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked())
+ maxWindow := wndFromSpace(e.rcvBufSize)
+ wndFromUsedBytes := maxWindow - e.rcvBufUsed
+
+ // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in
+ // cases where we receive a lot of small segments the segment overhead is a
+ // lot higher and we can run out socket buffer space before we can fill the
+ // previous window we advertised. In cases where we receive MSS sized or close
+ // MSS sized segments we will probably run out of window space before we
+ // exhaust receive buffer.
+ newWnd := wndFromAvailable
+ if newWnd > wndFromUsedBytes {
+ newWnd = wndFromUsedBytes
+ }
+ if newWnd < 0 {
+ newWnd = 0
+ }
+ return seqnum.Size(newWnd)
+}
+
+// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu.
+func (e *endpoint) selectWindow() (wnd seqnum.Size) {
+ e.rcvListMu.Lock()
+ wnd = e.selectWindowLocked()
+ e.rcvListMu.Unlock()
+ return wnd
+}
+
// windowCrossedACKThresholdLocked checks if the receive window to be announced
-// now would be under aMSS or under half receive buffer, whichever smaller. This
-// is useful as a receive side silly window syndrome prevention mechanism. If
-// window grows to reasonable value, we should send ACK to the sender to inform
-// the rx space is now large. We also want ensure a series of small read()'s
-// won't trigger a flood of spurious tiny ACK's.
+// would be under aMSS or under the window derived from half receive buffer,
+// whichever smaller. This is useful as a receive side silly window syndrome
+// prevention mechanism. If window grows to reasonable value, we should send ACK
+// to the sender to inform the rx space is now large. We also want ensure a
+// series of small read()'s won't trigger a flood of spurious tiny ACK's.
//
// For large receive buffers, the threshold is aMSS - once reader reads more
// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of
@@ -1420,17 +1571,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
//
// Precondition: e.mu and e.rcvListMu must be held.
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := e.receiveBufferAvailableLocked()
+ newAvail := int(e.selectWindowLocked())
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
}
-
threshold := int(e.amss)
- if threshold > e.rcvBufSize/2 {
- threshold = e.rcvBufSize / 2
+ // rcvBufFraction is the inverse of the fraction of receive buffer size that
+ // is used to decide if the available buffer space is now above it.
+ const rcvBufFraction = 2
+ if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold {
+ threshold = wndThreshold
}
-
switch {
case oldAvail < threshold && newAvail >= threshold:
return true, true
@@ -1486,12 +1638,12 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
case tcpip.ReuseAddressOption:
e.LockUser()
- e.reuseAddr = v
+ e.portFlags.TupleOnly = v
e.UnlockUser()
case tcpip.ReusePortOption:
e.LockUser()
- e.reusePort = v
+ e.portFlags.LoadBalanced = v
e.UnlockUser()
case tcpip.V6OnlyOption:
@@ -1549,21 +1701,34 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.UnlockUser()
e.notifyProtocolGoroutine(notifyMSSChanged)
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if attempting to set this option to
+ // anything other than path MTU discovery disabled.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
- var rs ReceiveBufferSizeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err))
+ }
+
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ if v < math.MaxInt32/SegOverheadFactor {
+ v *= SegOverheadFactor
if v < rs.Min {
v = rs.Min
}
- if v > rs.Max {
- v = rs.Max
- }
+ } else {
+ v = math.MaxInt32
}
- mask := uint32(notifyReceiveWindowChanged)
-
e.LockUser()
e.rcvListMu.Lock()
@@ -1577,14 +1742,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
v = 1 << scale
}
- // Make sure 2*size doesn't overflow.
- if v > math.MaxInt32/2 {
- v = math.MaxInt32 / 2
- }
-
- availBefore := e.receiveBufferAvailableLocked()
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvBufSize = v
- availAfter := e.receiveBufferAvailableLocked()
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvAutoParams.disabled = true
@@ -1592,24 +1752,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// syndrome prevetion, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- mask |= notifyNonZeroReceiveWindow
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
e.rcvListMu.Unlock()
e.UnlockUser()
- e.notifyProtocolGoroutine(mask)
case tcpip.SendBufferSizeOption:
// Make sure the send buffer size is within the min and max
// allowed.
- var ss SendBufferSizeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ var ss tcpip.TCPSendBufferSizeRangeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err))
+ }
+
+ if v > ss.Max {
+ v = ss.Max
+ }
+
+ if v < math.MaxInt32/SegOverheadFactor {
+ v *= SegOverheadFactor
if v < ss.Min {
v = ss.Min
}
- if v > ss.Max {
- v = ss.Max
- }
+ } else {
+ v = math.MaxInt32
}
e.sndBufMu.Lock()
@@ -1642,7 +1809,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return tcpip.ErrInvalidOptionValue
}
}
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
if v < rs.Min/2 {
v = rs.Min / 2
@@ -1656,10 +1823,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
- case tcpip.BindToDeviceOption:
- id := tcpip.NICID(v)
+ case *tcpip.BindToDeviceOption:
+ id := tcpip.NICID(*v)
if id != 0 && !e.stack.HasNIC(id) {
return tcpip.ErrUnknownDevice
}
@@ -1667,40 +1834,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.bindToDevice = id
e.UnlockUser()
- case tcpip.KeepaliveIdleOption:
+ case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
- e.keepalive.idle = time.Duration(v)
+ e.keepalive.idle = time.Duration(*v)
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- case tcpip.KeepaliveIntervalOption:
+ case *tcpip.KeepaliveIntervalOption:
e.keepalive.Lock()
- e.keepalive.interval = time.Duration(v)
+ e.keepalive.interval = time.Duration(*v)
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- case tcpip.OutOfBandInlineOption:
+ case *tcpip.OutOfBandInlineOption:
// We don't currently support disabling this option.
- case tcpip.TCPUserTimeoutOption:
+ case *tcpip.TCPUserTimeoutOption:
e.LockUser()
- e.userTimeout = time.Duration(v)
+ e.userTimeout = time.Duration(*v)
e.UnlockUser()
- case tcpip.CongestionControlOption:
+ case *tcpip.CongestionControlOption:
// Query the available cc algorithms in the stack and
// validate that the specified algorithm is actually
// supported in the stack.
- var avail tcpip.AvailableCongestionControlOption
+ var avail tcpip.TCPAvailableCongestionControlOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil {
return err
}
availCC := strings.Split(string(avail), " ")
for _, cc := range availCC {
- if v == tcpip.CongestionControlOption(cc) {
+ if *v == tcpip.CongestionControlOption(cc) {
e.LockUser()
state := e.EndpointState()
- e.cc = v
+ e.cc = *v
switch state {
case StateEstablished:
if e.EndpointState() == state {
@@ -1716,33 +1883,43 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// control algorithm is specified.
return tcpip.ErrNoSuchFile
- case tcpip.TCPLingerTimeoutOption:
+ case *tcpip.TCPLingerTimeoutOption:
e.LockUser()
- if v < 0 {
+
+ switch {
+ case *v < 0:
// Same as effectively disabling TCPLinger timeout.
- v = 0
- }
- var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption
- if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil {
- // We were unable to retrieve a stack config, just use
- // the DefaultTCPLingerTimeout.
- if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) {
- stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout)
+ *v = -1
+ case *v == 0:
+ // Same as the stack default.
+ var stackLingerTimeout tcpip.TCPLingerTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &stackLingerTimeout); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %v", ProtocolNumber, &stackLingerTimeout, err))
}
+ *v = stackLingerTimeout
+ case *v > tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout):
+ // Cap it to Stack's default TCP_LINGER2 timeout.
+ *v = tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout)
+ default:
}
- // Cap it to the stack wide TCPLinger timeout.
- if v > stkTCPLingerTimeout {
- v = stkTCPLingerTimeout
- }
- e.tcpLingerTimeout = time.Duration(v)
+
+ e.tcpLingerTimeout = time.Duration(*v)
e.UnlockUser()
- case tcpip.TCPDeferAcceptOption:
+ case *tcpip.TCPDeferAcceptOption:
e.LockUser()
- if time.Duration(v) > MaxRTO {
- v = tcpip.TCPDeferAcceptOption(MaxRTO)
+ if time.Duration(*v) > MaxRTO {
+ *v = tcpip.TCPDeferAcceptOption(MaxRTO)
}
- e.deferAccept = time.Duration(v)
+ e.deferAccept = time.Duration(*v)
+ e.UnlockUser()
+
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+
+ case *tcpip.LingerOption:
+ e.LockUser()
+ e.linger = *v
e.UnlockUser()
default:
@@ -1795,14 +1972,14 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.ReuseAddressOption:
e.LockUser()
- v := e.reuseAddr
+ v := e.portFlags.TupleOnly
e.UnlockUser()
return v, nil
case tcpip.ReusePortOption:
e.LockUser()
- v := e.reusePort
+ v := e.portFlags.LoadBalanced
e.UnlockUser()
return v, nil
@@ -1819,6 +1996,15 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
+ case tcpip.MulticastLoopOption:
+ return true, nil
+
+ case tcpip.AcceptConnOption:
+ e.LockUser()
+ defer e.UnlockUser()
+
+ return e.EndpointState() == StateListen, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -1853,6 +2039,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
v := header.TCPDefaultMSS
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // Always return the path MTU discovery disabled setting since
+ // it's the only one supported.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1886,21 +2077,17 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.UnlockUser()
return v, nil
+ case tcpip.MulticastTTLOption:
+ return 1, nil
+
default:
return -1, tcpip.ErrUnknownProtocolOption
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
switch o := opt.(type) {
- case tcpip.ErrorOption:
- e.lastErrorMu.Lock()
- err := e.lastError
- e.lastError = nil
- e.lastErrorMu.Unlock()
- return err
-
case *tcpip.BindToDeviceOption:
e.LockUser()
*o = tcpip.BindToDeviceOption(e.bindToDevice)
@@ -1952,6 +2139,24 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = tcpip.TCPDeferAcceptOption(e.deferAccept)
e.UnlockUser()
+ case *tcpip.OriginalDestinationOption:
+ e.LockUser()
+ ipt := e.stack.IPTables()
+ addr, port, err := ipt.OriginalDst(e.ID, e.NetProto)
+ e.UnlockUser()
+ if err != nil {
+ return err
+ }
+ *o = tcpip.OriginalDestinationOption{
+ Addr: addr,
+ Port: port,
+ }
+
+ case *tcpip.LingerOption:
+ e.LockUser()
+ *o = e.linger
+ e.UnlockUser()
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2049,8 +2254,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
defer r.Release()
- origID := e.ID
-
netProtos := []tcpip.NetworkProtocolNumber{netProto}
e.ID.LocalAddress = r.LocalAddress
e.ID.RemoteAddress = r.RemoteAddress
@@ -2058,7 +2261,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
if err != nil {
return err
}
@@ -2081,43 +2284,91 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
h.Write(portBuf)
portOffset := h.Sum32()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err))
+ }
+
+ reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal
+ if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly {
+ switch netProto {
+ case header.IPv4ProtocolNumber:
+ reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress)
+ case header.IPv6ProtocolNumber:
+ reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback
+ }
+ }
+
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- // reusePort is false below because connect cannot reuse a port even if
- // reusePort was set.
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) {
- return false, nil
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if err != tcpip.ErrPortInUse || !reuse {
+ return false, nil
+ }
+ transEPID := e.ID
+ transEPID.LocalPort = p
+ // Check if an endpoint is registered with demuxer in TIME-WAIT and if
+ // we can reuse it. If we can't find a transport endpoint then we just
+ // skip using this port as it's possible that either an endpoint has
+ // bound the port but not registered with demuxer yet (no listen/connect
+ // done yet) or the reservation was freed between the check above and
+ // the FindTransportEndpoint below. But rather than retry the same port
+ // we just skip it and move on.
+ transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r)
+ if transEP == nil {
+ // ReservePort failed but there is no registered endpoint with
+ // demuxer. Which indicates there is at least some endpoint that has
+ // bound the port.
+ return false, nil
+ }
+
+ tcpEP := transEP.(*endpoint)
+ tcpEP.LockUser()
+ // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but
+ // less than 1 second has elapsed since its recentTS was updated then
+ // we cannot reuse the port.
+ if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second {
+ tcpEP.UnlockUser()
+ return false, nil
+ }
+ // Since the endpoint is in TIME-WAIT it should be safe to acquire its
+ // Lock while holding the lock for this endpoint as endpoints in
+ // TIME-WAIT do not acquire locks on other endpoints.
+ tcpEP.workerCleanup = false
+ tcpEP.cleanupLocked()
+ tcpEP.notifyProtocolGoroutine(notifyAbort)
+ tcpEP.UnlockUser()
+ // Now try and Reserve again if it fails then we skip.
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ return false, nil
+ }
}
id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
- case nil:
- // Port picking successful. Save the details of
- // the selected port.
- e.ID = id
- e.boundBindToDevice = e.bindToDevice
- return true, nil
- case tcpip.ErrPortInUse:
- return false, nil
- default:
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err == tcpip.ErrPortInUse {
+ return false, nil
+ }
return false, err
}
+
+ // Port picking successful. Save the details of
+ // the selected port.
+ e.ID = id
+ e.isPortReserved = true
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
+ e.boundDest = addr
+ return true, nil
}); err != nil {
return err
}
}
- // Remove the port reservation. This can happen when Bind is called
- // before Connect: in such a case we don't want to hold on to
- // reservations anymore.
- if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
- e.isPortReserved = false
- }
-
e.isRegistered = true
e.setEndpointState(StateConnecting)
e.route = r.Clone()
@@ -2296,7 +2547,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
return err
}
@@ -2330,7 +2581,9 @@ func (e *endpoint) startAcceptedLoop() {
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode.
-func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+//
+// addr if not-nil will contain the peer address of the returned endpoint.
+func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -2352,6 +2605,9 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
default:
return nil, nil, tcpip.ErrWouldBlock
}
+ if peerAddr != nil {
+ *peerAddr = n.getRemoteAddress()
+ }
return n, n.waiterQueue, nil
}
@@ -2388,46 +2644,45 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- flags := ports.Flags{
- LoadBalanced: e.reusePort,
+ var nic tcpip.NICID
+ // If an address is specified, we must ensure that it's one of our
+ // local addresses.
+ if len(addr.Addr) != 0 {
+ nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nic == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ e.ID.LocalAddress = addr.Addr
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice)
+
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ id := e.ID
+ id.LocalPort = p
+ // CheckRegisterTransportEndpoint should only return an error if there is a
+ // listening endpoint bound with the same id and portFlags and bindToDevice
+ // options.
+ //
+ // NOTE: Only listening and connected endpoint register with
+ // demuxer. Further connected endpoints always have a remote
+ // address/port. Hence this will only return an error if there is a matching
+ // listening endpoint.
+ if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil {
+ return false
+ }
+ return true
+ })
if err != nil {
return err
}
e.boundBindToDevice = e.bindToDevice
- e.boundPortFlags = flags
+ e.boundPortFlags = e.portFlags
+ // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct.
+ e.boundNICID = nic
e.isPortReserved = true
e.effectiveNetProtos = netProtos
e.ID.LocalPort = port
- // Any failures beyond this point must remove the port registration.
- defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
- if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice)
- e.isPortReserved = false
- e.effectiveNetProtos = nil
- e.ID.LocalPort = 0
- e.ID.LocalAddress = ""
- e.boundNICID = 0
- e.boundBindToDevice = 0
- e.boundPortFlags = ports.Flags{}
- }
- }(e.boundPortFlags, e.boundBindToDevice)
-
- // If an address is specified, we must ensure that it's one of our
- // local addresses.
- if len(addr.Addr) != 0 {
- nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
- if nic == 0 {
- return tcpip.ErrBadLocalAddress
- }
-
- e.boundNICID = nic
- e.ID.LocalAddress = addr.Addr
- }
-
// Mark endpoint as bound.
e.setEndpointState(StateBound)
@@ -2455,14 +2710,18 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
+ return e.getRemoteAddress(), nil
+}
+
+func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
return tcpip.FullAddress{
Addr: e.ID.RemoteAddress,
Port: e.ID.RemotePort,
NIC: e.boundNICID,
- }, nil
+ }
}
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
// land at the Dispatcher which then can either delivery using the
// worker go routine or directly do the invoke the tcp processing inline
@@ -2481,7 +2740,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
@@ -2492,6 +2751,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.sndBufMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
+
+ case stack.ControlNoRoute:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNoRoute
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
+
+ case stack.ControlNetworkUnreachable:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNetworkUnreachable
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
}
}
@@ -2518,13 +2789,8 @@ func (e *endpoint) updateSndBufferUsage(v int) {
func (e *endpoint) readyToRead(s *segment) {
e.rcvListMu.Lock()
if s != nil {
+ e.rcvBufUsed += s.payloadSize()
s.incRef()
- e.rcvBufUsed += s.data.Size()
- // Increase counter if the receive window falls down below MSS
- // or half receive buffer size, whichever smaller.
- if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
- e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- }
e.rcvList.PushBack(s)
} else {
e.rcvClosed = true
@@ -2539,15 +2805,17 @@ func (e *endpoint) readyToRead(s *segment) {
func (e *endpoint) receiveBufferAvailableLocked() int {
// We may use more bytes than the buffer size when the receive buffer
// shrinks.
- if e.rcvBufUsed >= e.rcvBufSize {
+ memUsed := e.receiveMemUsed()
+ if memUsed >= e.rcvBufSize {
return 0
}
- return e.rcvBufSize - e.rcvBufUsed
+ return e.rcvBufSize - memUsed
}
// receiveBufferAvailable calculates how many bytes are still available in the
-// receive buffer.
+// receive buffer based on the actual memory used by all segments held in
+// receive buffer/pending and segment queue.
func (e *endpoint) receiveBufferAvailable() int {
e.rcvListMu.Lock()
available := e.receiveBufferAvailableLocked()
@@ -2555,16 +2823,37 @@ func (e *endpoint) receiveBufferAvailable() int {
return available
}
+// receiveBufferUsed returns the amount of in-use receive buffer.
+func (e *endpoint) receiveBufferUsed() int {
+ e.rcvListMu.Lock()
+ used := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+ return used
+}
+
+// receiveBufferSize returns the current size of the receive buffer.
func (e *endpoint) receiveBufferSize() int {
e.rcvListMu.Lock()
size := e.rcvBufSize
e.rcvListMu.Unlock()
-
return size
}
+// receiveMemUsed returns the total memory in use by segments held by this
+// endpoint.
+func (e *endpoint) receiveMemUsed() int {
+ return int(atomic.LoadInt32(&e.rcvMemUsed))
+}
+
+// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed.
+func (e *endpoint) updateReceiveMemUsed(delta int) {
+ atomic.AddInt32(&e.rcvMemUsed, int32(delta))
+}
+
+// maxReceiveBufferSize returns the stack wide maximum receive buffer size for
+// an endpoint.
func (e *endpoint) maxReceiveBufferSize() int {
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
// As a fallback return the hardcoded max buffer size.
return MaxBufferSize
@@ -2611,15 +2900,14 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
// timestamp returns the timestamp value to be used in the TSVal field of the
// timestamp option for outgoing TCP segments for a given endpoint.
func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(e.tsOffset)
+ return tcpTimeStamp(time.Now(), e.tsOffset)
}
// tcpTimeStamp returns a timestamp offset by the provided offset. This is
// not inlined above as it's used when SYN cookies are in use and endpoint
// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(offset uint32) uint32 {
- now := time.Now()
- return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset
+func tcpTimeStamp(curTime time.Time, offset uint32) uint32 {
+ return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset
}
// timeStampOffset returns a randomized timestamp offset to be used when sending
@@ -2645,7 +2933,7 @@ func timeStampOffset() uint32 {
// if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option.
func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
- var v SACKEnabled
+ var v tcpip.TCPSACKEnabled
if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
return
@@ -2714,7 +3002,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
RcvAcc: e.rcv.rcvAcc,
RcvWndScale: e.rcv.rcvWndScale,
PendingBufUsed: e.rcv.pendingBufUsed,
- PendingBufSize: e.rcv.pendingBufSize,
}
// Copy sender state.
@@ -2762,6 +3049,15 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
WEst: cubic.wEst,
}
}
+
+ rc := e.snd.rc
+ s.Sender.RACKState = stack.TCPRACKState{
+ XmitTime: rc.xmitTime,
+ EndSequence: rc.endSequence,
+ FACK: rc.fack,
+ RTT: rc.rtt,
+ Reord: rc.reorderSeen,
+ }
return s
}
@@ -2830,8 +3126,3 @@ func (e *endpoint) Wait() {
<-notifyCh
}
}
-
-func mssForRoute(r *stack.Route) uint16 {
- // TODO(b/143359391): Respect TCP Min and Max size.
- return uint16(r.MTU() - header.TCPMinimumSize)
-}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index fc43c11e2..b25431467 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -44,16 +44,15 @@ func (e *endpoint) drainSegmentLocked() {
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets.
- e.segmentQueue.setLimit(0)
+ e.segmentQueue.freeze()
e.mu.Lock()
defer e.mu.Unlock()
- switch e.EndpointState() {
- case StateInitial, StateBound:
- // TODO(b/138137272): this enumeration duplicates
- // EndpointState.connected. remove it.
- case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ epState := e.EndpointState()
+ switch {
+ case epState == StateInitial || epState == StateBound:
+ case epState.connected() || epState.handshake():
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
@@ -69,15 +68,16 @@ func (e *endpoint) beforeSave() {
break
}
fallthrough
- case StateListen, StateConnecting:
+ case epState == StateListen || epState == StateConnecting:
e.drainSegmentLocked()
- if e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ // Refresh epState, since drainSegmentLocked may have changed it.
+ epState = e.EndpointState()
+ if !epState.closed() {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
- break
}
- case StateError, StateClose:
+ case epState.closed():
for e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
@@ -93,10 +93,6 @@ func (e *endpoint) beforeSave() {
if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
panic("endpoint still has waiters upon save")
}
-
- if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) {
- panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
- }
}
// saveAcceptedChan is invoked by stateify.
@@ -148,23 +144,23 @@ var connectingLoading sync.WaitGroup
// Bound endpoint loading happens last.
// loadState is invoked by stateify.
-func (e *endpoint) loadState(state EndpointState) {
+func (e *endpoint) loadState(epState EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
// For restore purposes we treat TimeWait like a connected endpoint.
- if state.connected() || state == StateTimeWait {
+ if epState.connected() || epState == StateTimeWait {
connectedLoading.Add(1)
}
- switch state {
- case StateListen:
+ switch {
+ case epState == StateListen:
listenLoading.Add(1)
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
connectingLoading.Add(1)
}
// Directly update the state here rather than using e.setEndpointState
// as the endpoint is still being loaded and the stack reference is not
// yet initialized.
- atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+ atomic.StoreUint32((*uint32)(&e.state), uint32(epState))
}
// afterLoad is invoked by stateify.
@@ -182,34 +178,41 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
- state := e.origEndpointState
- switch state {
+ e.segmentQueue.thaw()
+ epState := e.origEndpointState
+ switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
- var ss SendBufferSizeOption
+ var ss tcpip.TCPSendBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
}
- if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
+ }
+
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
}
}
}
bind := func() {
- if len(e.BindAddr) == 0 {
- e.BindAddr = e.ID.LocalAddress
+ addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort})
+ if err != nil {
+ panic("unable to parse BindAddr: " + err.String())
}
- addr := e.BindAddr
- port := e.ID.LocalPort
- if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil {
- panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err))
+ if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok {
+ panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
}
+ e.isPortReserved = true
+
+ // Mark endpoint as bound.
+ e.setEndpointState(StateBound)
}
- switch state {
- case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ switch {
+ case epState.connected():
bind()
if len(e.connectingAddress) == 0 {
e.connectingAddress = e.ID.RemoteAddress
@@ -232,13 +235,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
closed := e.closed
e.mu.Unlock()
e.notifyProtocolGoroutine(notifyTickleWorker)
- if state == StateFinWait2 && closed {
+ if epState == StateFinWait2 && closed {
// If the endpoint has been closed then make sure we notify so
// that the FIN_WAIT2 timer is started after a restore.
e.notifyProtocolGoroutine(notifyClose)
}
connectedLoading.Done()
- case StateListen:
+ case epState == StateListen:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -255,7 +258,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -267,7 +270,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectingLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateBound:
+ case epState == StateBound:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -276,27 +279,16 @@ func (e *endpoint) Resume(s *stack.Stack) {
bind()
tcpip.AsyncLoading.Done()
}()
- case StateClose:
- if e.isPortReserved {
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- connectingLoading.Wait()
- bind()
- e.setEndpointState(StateClose)
- tcpip.AsyncLoading.Done()
- }()
- }
+ case epState == StateClose:
+ e.isPortReserved = false
e.state = StateClose
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
- case StateError:
+ case epState == StateError:
e.state = StateError
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
-
}
// saveLastError is invoked by stateify.
@@ -317,6 +309,16 @@ func (e *endpoint) loadLastError(s string) {
e.lastError = tcpip.StringToError(s)
}
+// saveRecentTSTime is invoked by stateify.
+func (e *endpoint) saveRecentTSTime() unixTime {
+ return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()}
+}
+
+// loadRecentTSTime is invoked by stateify.
+func (e *endpoint) loadRecentTSTime(unix unixTime) {
+ e.recentTSTime = time.Unix(unix.second, unix.nano)
+}
+
// saveHardError is invoked by stateify.
func (e *EndpointInfo) saveHardError() string {
if e.HardError == nil {
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 704d01c64..070b634b4 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2a2a7ddeb..5bce73605 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -12,12 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package tcp contains the implementation of the TCP transport protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing tcp.NewProtocol() as one of the
-// transport protocols when calling stack.New(). Then endpoints can be created
-// by passing tcp.ProtocolNumber as the transport protocol number when calling
-// Stack.NewEndpoint().
+// Package tcp contains the implementation of the TCP transport protocol.
package tcp
import (
@@ -29,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
@@ -61,6 +57,10 @@ const (
// FIN_WAIT_2 state before being marked closed.
DefaultTCPLingerTimeout = 60 * time.Second
+ // MaxTCPLingerTimeout is the maximum amount of time that sockets
+ // linger in FIN_WAIT_2 state before being marked closed.
+ MaxTCPLingerTimeout = 120 * time.Second
+
// DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
// in TIME_WAIT state before being marked closed.
DefaultTCPTimeWaitTimeout = 60 * time.Second
@@ -70,29 +70,6 @@ const (
DefaultSynRetries = 6
)
-// SACKEnabled option can be used to enable SACK support in the TCP
-// protocol. See: https://tools.ietf.org/html/rfc2018.
-type SACKEnabled bool
-
-// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol.
-type DelayEnabled bool
-
-// SendBufferSizeOption allows the default, min and max send buffer sizes for
-// TCP endpoints to be queried or configured.
-type SendBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
-// ReceiveBufferSizeOption allows the default, min and max receive buffer size
-// for TCP endpoints to be queried or configured.
-type ReceiveBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
const (
ccReno = "reno"
ccCubic = "cubic"
@@ -156,22 +133,26 @@ func (s *synRcvdCounter) Threshold() uint64 {
}
type protocol struct {
+ stack *stack.Stack
+
mu sync.RWMutex
sackEnabled bool
+ recovery tcpip.TCPRecovery
delayEnabled bool
- sendBufferSize SendBufferSizeOption
- recvBufferSize ReceiveBufferSizeOption
+ sendBufferSize tcpip.TCPSendBufferSizeRangeOption
+ recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
- tcpLingerTimeout time.Duration
- tcpTimeWaitTimeout time.Duration
+ lingerTimeout time.Duration
+ timeWaitTimeout time.Duration
+ timeWaitReuse tcpip.TCPTimeWaitReuseOption
minRTO time.Duration
maxRTO time.Duration
maxRetries uint32
synRcvdCount synRcvdCounter
synRetries uint8
- dispatcher *dispatcher
+ dispatcher dispatcher
}
// Number returns the tcp protocol number.
@@ -180,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid tcp packet size.
@@ -206,7 +187,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// to a specific processing queue. Each queue is serviced by its own processor
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
-func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
p.dispatcher.queuePacket(r, ep, id, pkt)
}
@@ -217,21 +198,20 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
s := newSegment(r, id, pkt)
defer s.decRef()
if !s.parse() || !s.csumValid {
- return false
+ return stack.UnknownDestinationPacketMalformed
}
- // There's nothing to do if this is already a reset packet.
- if s.flagIsSet(header.TCPFlagRst) {
- return true
+ if !s.flagIsSet(header.TCPFlagRst) {
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
}
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
- return true
+ return stack.UnknownDestinationPacketHandled
}
// replyWithReset replies to the given segment with a reset segment.
@@ -269,43 +249,49 @@ func replyWithReset(s *segment, tos, ttl uint8) {
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case SACKEnabled:
+ case *tcpip.TCPSACKEnabled:
p.mu.Lock()
- p.sackEnabled = bool(v)
+ p.sackEnabled = bool(*v)
p.mu.Unlock()
return nil
- case DelayEnabled:
+ case *tcpip.TCPRecovery:
p.mu.Lock()
- p.delayEnabled = bool(v)
+ p.recovery = *v
p.mu.Unlock()
return nil
- case SendBufferSizeOption:
+ case *tcpip.TCPDelayEnabled:
+ p.mu.Lock()
+ p.delayEnabled = bool(*v)
+ p.mu.Unlock()
+ return nil
+
+ case *tcpip.TCPSendBufferSizeRangeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.sendBufferSize = v
+ p.sendBufferSize = *v
p.mu.Unlock()
return nil
- case ReceiveBufferSizeOption:
+ case *tcpip.TCPReceiveBufferSizeRangeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.recvBufferSize = v
+ p.recvBufferSize = *v
p.mu.Unlock()
return nil
- case tcpip.CongestionControlOption:
+ case *tcpip.CongestionControlOption:
for _, c := range p.availableCongestionControl {
- if string(v) == c {
+ if string(*v) == c {
p.mu.Lock()
- p.congestionControl = string(v)
+ p.congestionControl = string(*v)
p.mu.Unlock()
return nil
}
@@ -314,66 +300,79 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
// is specified.
return tcpip.ErrNoSuchFile
- case tcpip.ModerateReceiveBufferOption:
+ case *tcpip.TCPModerateReceiveBufferOption:
p.mu.Lock()
- p.moderateReceiveBuffer = bool(v)
+ p.moderateReceiveBuffer = bool(*v)
p.mu.Unlock()
return nil
- case tcpip.TCPLingerTimeoutOption:
- if v < 0 {
- v = 0
- }
+ case *tcpip.TCPLingerTimeoutOption:
p.mu.Lock()
- p.tcpLingerTimeout = time.Duration(v)
+ if *v < 0 {
+ p.lingerTimeout = 0
+ } else {
+ p.lingerTimeout = time.Duration(*v)
+ }
p.mu.Unlock()
return nil
- case tcpip.TCPTimeWaitTimeoutOption:
- if v < 0 {
- v = 0
- }
+ case *tcpip.TCPTimeWaitTimeoutOption:
p.mu.Lock()
- p.tcpTimeWaitTimeout = time.Duration(v)
+ if *v < 0 {
+ p.timeWaitTimeout = 0
+ } else {
+ p.timeWaitTimeout = time.Duration(*v)
+ }
p.mu.Unlock()
return nil
- case tcpip.TCPMinRTOOption:
- if v < 0 {
- v = tcpip.TCPMinRTOOption(MinRTO)
+ case *tcpip.TCPTimeWaitReuseOption:
+ if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly {
+ return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.minRTO = time.Duration(v)
+ p.timeWaitReuse = *v
p.mu.Unlock()
return nil
- case tcpip.TCPMaxRTOOption:
- if v < 0 {
- v = tcpip.TCPMaxRTOOption(MaxRTO)
+ case *tcpip.TCPMinRTOOption:
+ p.mu.Lock()
+ if *v < 0 {
+ p.minRTO = MinRTO
+ } else {
+ p.minRTO = time.Duration(*v)
}
+ p.mu.Unlock()
+ return nil
+
+ case *tcpip.TCPMaxRTOOption:
p.mu.Lock()
- p.maxRTO = time.Duration(v)
+ if *v < 0 {
+ p.maxRTO = MaxRTO
+ } else {
+ p.maxRTO = time.Duration(*v)
+ }
p.mu.Unlock()
return nil
- case tcpip.TCPMaxRetriesOption:
+ case *tcpip.TCPMaxRetriesOption:
p.mu.Lock()
- p.maxRetries = uint32(v)
+ p.maxRetries = uint32(*v)
p.mu.Unlock()
return nil
- case tcpip.TCPSynRcvdCountThresholdOption:
+ case *tcpip.TCPSynRcvdCountThresholdOption:
p.mu.Lock()
- p.synRcvdCount.SetThreshold(uint64(v))
+ p.synRcvdCount.SetThreshold(uint64(*v))
p.mu.Unlock()
return nil
- case tcpip.TCPSynRetriesOption:
- if v < 1 || v > 255 {
+ case *tcpip.TCPSynRetriesOption:
+ if *v < 1 || *v > 255 {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.synRetries = uint8(v)
+ p.synRetries = uint8(*v)
p.mu.Unlock()
return nil
@@ -383,27 +382,33 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
// Option implements stack.TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case *SACKEnabled:
+ case *tcpip.TCPSACKEnabled:
p.mu.RLock()
- *v = SACKEnabled(p.sackEnabled)
+ *v = tcpip.TCPSACKEnabled(p.sackEnabled)
p.mu.RUnlock()
return nil
- case *DelayEnabled:
+ case *tcpip.TCPRecovery:
p.mu.RLock()
- *v = DelayEnabled(p.delayEnabled)
+ *v = tcpip.TCPRecovery(p.recovery)
p.mu.RUnlock()
return nil
- case *SendBufferSizeOption:
+ case *tcpip.TCPDelayEnabled:
+ p.mu.RLock()
+ *v = tcpip.TCPDelayEnabled(p.delayEnabled)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPSendBufferSizeRangeOption:
p.mu.RLock()
*v = p.sendBufferSize
p.mu.RUnlock()
return nil
- case *ReceiveBufferSizeOption:
+ case *tcpip.TCPReceiveBufferSizeRangeOption:
p.mu.RLock()
*v = p.recvBufferSize
p.mu.RUnlock()
@@ -415,27 +420,33 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.RUnlock()
return nil
- case *tcpip.AvailableCongestionControlOption:
+ case *tcpip.TCPAvailableCongestionControlOption:
p.mu.RLock()
- *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ *v = tcpip.TCPAvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
p.mu.RUnlock()
return nil
- case *tcpip.ModerateReceiveBufferOption:
+ case *tcpip.TCPModerateReceiveBufferOption:
p.mu.RLock()
- *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer)
+ *v = tcpip.TCPModerateReceiveBufferOption(p.moderateReceiveBuffer)
p.mu.RUnlock()
return nil
case *tcpip.TCPLingerTimeoutOption:
p.mu.RLock()
- *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
+ *v = tcpip.TCPLingerTimeoutOption(p.lingerTimeout)
p.mu.RUnlock()
return nil
case *tcpip.TCPTimeWaitTimeoutOption:
p.mu.RLock()
- *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.timeWaitTimeout)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitReuseOption:
+ p.mu.RLock()
+ *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse)
p.mu.RUnlock()
return nil
@@ -490,20 +501,37 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter {
return &p.synRcvdCount
}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ return parse.TCP(pkt)
+}
+
// NewProtocol returns a TCP transport protocol.
-func NewProtocol() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
+ p := protocol{
+ stack: s,
+ sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{
+ Min: MinBufferSize,
+ Default: DefaultSendBufferSize,
+ Max: MaxBufferSize,
+ },
+ recvBufferSize: tcpip.TCPReceiveBufferSizeRangeOption{
+ Min: MinBufferSize,
+ Default: DefaultReceiveBufferSize,
+ Max: MaxBufferSize,
+ },
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
- tcpLingerTimeout: DefaultTCPLingerTimeout,
- tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ lingerTimeout: DefaultTCPLingerTimeout,
+ timeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly,
synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
- dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
synRetries: DefaultSynRetries,
minRTO: MinRTO,
maxRTO: MaxRTO,
maxRetries: MaxRetries,
+ recovery: tcpip.TCPRACKLossDetection,
}
+ p.dispatcher.init(runtime.GOMAXPROCS(0))
+ return &p
}
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
new file mode 100644
index 000000000..d312b1b8b
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// RACK is a loss detection algorithm used in TCP to detect packet loss and
+// reordering using transmission timestamp of the packets instead of packet or
+// sequence counts. To use RACK, SACK should be enabled on the connection.
+
+// rackControl stores the rack related fields.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-6.1
+//
+// +stateify savable
+type rackControl struct {
+ // endSequence is the ending TCP sequence number of rackControl.seg.
+ endSequence seqnum.Value
+
+ // dsack indicates if the connection has seen a DSACK.
+ dsack bool
+
+ // fack is the highest selectively or cumulatively acknowledged
+ // sequence.
+ fack seqnum.Value
+
+ // minRTT is the estimated minimum RTT of the connection.
+ minRTT time.Duration
+
+ // rtt is the RTT of the most recently delivered packet on the
+ // connection (either cumulatively acknowledged or selectively
+ // acknowledged) that was not marked invalid as a possible spurious
+ // retransmission.
+ rtt time.Duration
+
+ // reorderSeen indicates if reordering has been detected on this
+ // connection.
+ reorderSeen bool
+
+ // xmitTime is the latest transmission timestamp of rackControl.seg.
+ xmitTime time.Time `state:".(unixTime)"`
+}
+
+// update will update the RACK related fields when an ACK has been received.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) {
+ rtt := time.Now().Sub(seg.xmitTime)
+
+ // If the ACK is for a retransmitted packet, do not update if it is a
+ // spurious inference which is determined by below checks:
+ // 1. When Timestamping option is available, if the TSVal is less than the
+ // transmit time of the most recent retransmitted packet.
+ // 2. When RTT calculated for the packet is less than the smoothed RTT
+ // for the connection.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+ // step 2
+ if seg.xmitCount > 1 {
+ if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
+ if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) {
+ return
+ }
+ }
+ if rtt < rc.minRTT {
+ return
+ }
+ }
+
+ rc.rtt = rtt
+
+ // The sender can either track a simple global minimum of all RTT
+ // measurements from the connection, or a windowed min-filtered value
+ // of recent RTT measurements. This implementation keeps track of the
+ // simple global minimum of all RTTs for the connection.
+ if rtt < rc.minRTT || rc.minRTT == 0 {
+ rc.minRTT = rtt
+ }
+
+ // Update rc.xmitTime and rc.endSequence to the transmit time and
+ // ending sequence number of the packet which has been acknowledged
+ // most recently.
+ endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) {
+ rc.xmitTime = seg.xmitTime
+ rc.endSequence = endSeq
+ }
+}
+
+// detectReorder detects if packet reordering has been observed.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+// * Step 3: Detect data segment reordering.
+// To detect reordering, the sender looks for original data segments being
+// delivered out of order. To detect such cases, the sender tracks the
+// highest sequence selectively or cumulatively acknowledged in the RACK.fack
+// variable. The name "fack" stands for the most "Forward ACK" (this term is
+// adopted from [FACK]). If a never retransmitted segment that's below
+// RACK.fack is (selectively or cumulatively) acknowledged, it has been
+// delivered out of order. The sender sets RACK.reord to TRUE if such segment
+// is identified.
+func (rc *rackControl) detectReorder(seg *segment) {
+ endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ if rc.fack.LessThan(endSeq) {
+ rc.fack = endSeq
+ return
+ }
+
+ if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 {
+ rc.reorderSeen = true
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go
new file mode 100644
index 000000000..c9dc7e773
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rack_state.go
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveXmitTime is invoked by stateify.
+func (rc *rackControl) saveXmitTime() unixTime {
+ return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (rc *rackControl) loadXmitTime(unix unixTime) {
+ rc.xmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index dd89a292a..8e0b7c843 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -43,26 +43,32 @@ type receiver struct {
// rcvWnd is the non-scaled receive window last advertised to the peer.
rcvWnd seqnum.Size
+ // rcvWUP is the rcvNxt value at the last window update sent.
+ rcvWUP seqnum.Value
+
rcvWndScale uint8
closed bool
+ // pendingRcvdSegments is bounded by the receive buffer size of the
+ // endpoint.
pendingRcvdSegments segmentHeap
- pendingBufUsed seqnum.Size
- pendingBufSize seqnum.Size
+ // pendingBufUsed tracks the total number of bytes (including segment
+ // overhead) currently queued in pendingRcvdSegments.
+ pendingBufUsed int
// Time when the last ack was received.
lastRcvdAckTime time.Time `state:".(unixTime)"`
}
-func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
return &receiver{
ep: ep,
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
rcvWnd: rcvWnd,
+ rcvWUP: irs + 1,
rcvWndScale: rcvWndScale,
- pendingBufSize: pendingBufSize,
lastRcvdAckTime: time.Now(),
}
}
@@ -82,19 +88,54 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
}
+// currentWindow returns the available space in the window that was advertised
+// last to our peer.
+func (r *receiver) currentWindow() (curWnd seqnum.Size) {
+ endOfWnd := r.rcvWUP.Add(r.rcvWnd)
+ if endOfWnd.LessThan(r.rcvNxt) {
+ // return 0 if r.rcvNxt is past the end of the previously advertised window.
+ // This can happen because we accept a large segment completely even if
+ // accepting it causes it to partially exceed the advertised window.
+ return 0
+ }
+ return r.rcvNxt.Size(endOfWnd)
+}
+
// getSendParams returns the parameters needed by the sender when building
// segments to send.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
- // Calculate the window size based on the available buffer space.
- receiveBufferAvailable := r.ep.receiveBufferAvailable()
- acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable))
- if r.rcvAcc.LessThan(acc) {
- r.rcvAcc = acc
+ newWnd := r.ep.selectWindow()
+ curWnd := r.currentWindow()
+ // Update rcvAcc only if new window is > previously advertised window. We
+ // should never shrink the acceptable sequence space once it has been
+ // advertised the peer. If we shrink the acceptable sequence space then we
+ // would end up dropping bytes that might already be in flight.
+ // ==================================================== sequence space.
+ // ^ ^ ^ ^
+ // rcvWUP rcvNxt rcvAcc new rcvAcc
+ // <=====curWnd ===>
+ // <========= newWnd > curWnd ========= >
+ if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) {
+ // If the new window moves the right edge, then update rcvAcc.
+ r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd))
+ } else {
+ if newWnd == 0 {
+ // newWnd is zero but we can't advertise a zero as it would cause window
+ // to shrink so just increment a metric to record this event.
+ r.ep.stats.ReceiveErrors.WantZeroRcvWindow.Increment()
+ }
+ newWnd = curWnd
}
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
- r.rcvWnd = r.rcvNxt.Size(r.rcvAcc)
- return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
+ r.rcvWnd = newWnd
+ r.rcvWUP = r.rcvNxt
+ scaledWnd := r.rcvWnd >> r.rcvWndScale
+ if scaledWnd == 0 {
+ // Increment a metric if we are advertising an actual zero window.
+ r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
+ }
+ return r.rcvNxt, scaledWnd
}
// nonZeroWindow is called when the receive window grows from zero to nonzero;
@@ -195,7 +236,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
}
for i := first; i < len(r.pendingRcvdSegments); i++ {
+ r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize()
r.pendingRcvdSegments[i].decRef()
+
// Note that slice truncation does not allow garbage collection of
// truncated items, thus truncated items must be set to nil to avoid
// memory leaks.
@@ -268,14 +311,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// If we are in one of the shutdown states then we need to do
// additional checks before we try and process the segment.
switch state {
- case StateCloseWait:
- // If the ACK acks something not yet sent then we send an ACK.
- if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
- r.ep.snd.sendAck()
- return true, nil
- }
- fallthrough
- case StateClosing, StateLastAck:
+ case StateCloseWait, StateClosing, StateLastAck:
if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
// Just drop the segment as we have
// already received a FIN and this
@@ -284,9 +320,31 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
return true, nil
}
fallthrough
- case StateFinWait1:
- fallthrough
- case StateFinWait2:
+ case StateFinWait1, StateFinWait2:
+ // If the ACK acks something not yet sent then we send an ACK.
+ //
+ // RFC793, page 37: If the connection is in a synchronized state,
+ // (ESTABLISHED, FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK,
+ // TIME-WAIT), any unacceptable segment (out of window sequence number
+ // or unacceptable acknowledgment number) must elicit only an empty
+ // acknowledgment segment containing the current send-sequence number
+ // and an acknowledgment indicating the next sequence number expected
+ // to be received, and the connection remains in the same state.
+ //
+ // Just as on Linux, we do not apply this behavior when state is
+ // ESTABLISHED.
+ // Linux receive processing for all states except ESTABLISHED and
+ // TIME_WAIT is here where if the ACK check fails, we attempt to
+ // reply back with an ACK with correct seq/ack numbers.
+ // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L6186
+ // The ESTABLISHED state processing is here where if the ACK check
+ // fails, we ignore the packet:
+ // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591
+ if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
+ r.ep.snd.sendAck()
+ return true, nil
+ }
+
// If we are closed for reads (either due to an
// incoming FIN or the user calling shutdown(..,
// SHUT_RD) then any data past the rcvNxt should
@@ -369,10 +427,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
- // We only store the segment if it's within our buffer
- // size limit.
- if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += s.logicalLen()
+ // We only store the segment if it's within our buffer size limit.
+ //
+ // Only use 75% of the receive buffer queue for out-of-order
+ // segments. This ensures that we always leave some space for the inorder
+ // segments to arrive allowing pending segments to be processed and
+ // delivered to the user.
+ if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 {
+ r.ep.rcvListMu.Lock()
+ r.pendingBufUsed += s.segMemSize()
+ r.ep.rcvListMu.Unlock()
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -406,7 +470,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= s.logicalLen()
+ r.ep.rcvListMu.Lock()
+ r.pendingBufUsed -= s.segMemSize()
+ r.ep.rcvListMu.Unlock()
s.decRef()
}
return false, nil
@@ -421,6 +487,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// Just silently drop any RST packets in TIME_WAIT. We do not support
// TIME_WAIT assasination as a result we confirm w/ fix 1 as described
// in https://tools.ietf.org/html/rfc1337#section-3.
+ //
+ // This behavior overrides RFC793 page 70 where we transition to CLOSED
+ // on receiving RST, which is also default Linux behavior.
+ // On Linux the RST can be ignored by setting sysctl net.ipv4.tcp_rfc1337.
+ //
+ // As we do not yet support PAWS, we are being conservative in ignoring
+ // RSTs by default.
if s.flagIsSet(header.TCPFlagRst) {
return false, false
}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go
index 7ef2df377..833a7b470 100644
--- a/pkg/tcpip/transport/tcp/sack_scoreboard.go
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go
@@ -164,7 +164,7 @@ func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool {
return found
}
-// Dump prints the state of the scoreboard structure.
+// String returns human-readable state of the scoreboard structure.
func (s *SACKScoreboard) String() string {
var str strings.Builder
str.WriteString("SACKScoreboard: {")
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 074edded6..1f9c5cf50 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "fmt"
"sync/atomic"
"time"
@@ -24,6 +25,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// queueFlags are used to indicate which queue of an endpoint a particular segment
+// belongs to. This is used to track memory accounting correctly.
+type queueFlags uint8
+
+const (
+ recvQ queueFlags = 1 << iota
+ sendQ
+)
+
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
// segment is mostly immutable, the only field allowed to change is viewToDeliver.
@@ -32,9 +42,12 @@ import (
type segment struct {
segmentEntry
refCnt int32
+ ep *endpoint
+ qFlags queueFlags
id stack.TransportEndpointID `state:"manual"`
route stack.Route `state:"manual"`
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View `state:"nosave"`
@@ -58,15 +71,19 @@ type segment struct {
// xmitTime is the last transmit time of this segment.
xmitTime time.Time `state:".(unixTime)"`
xmitCount uint32
+
+ // acked indicates if the segment has already been SACKed.
+ acked bool
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
s.data = pkt.Data.Clone(s.views[:])
+ s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
return s
}
@@ -98,6 +115,8 @@ func (s *segment) clone() *segment {
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
+ ep: s.ep,
+ qFlags: s.qFlags,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -113,8 +132,34 @@ func (s *segment) flagsAreSet(flags uint8) bool {
return s.flags&flags == flags
}
+// setOwner sets the owning endpoint for this segment. Its required
+// to be called to ensure memory accounting for receive/send buffer
+// queues is done properly.
+func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) {
+ switch qFlags {
+ case recvQ:
+ ep.updateReceiveMemUsed(s.segMemSize())
+ case sendQ:
+ // no memory account for sendQ yet.
+ default:
+ panic(fmt.Sprintf("unexpected queue flag %b", qFlags))
+ }
+ s.ep = ep
+ s.qFlags = qFlags
+}
+
func (s *segment) decRef() {
if atomic.AddInt32(&s.refCnt, -1) == 0 {
+ if s.ep != nil {
+ switch s.qFlags {
+ case recvQ:
+ s.ep.updateReceiveMemUsed(-s.segMemSize())
+ case sendQ:
+ // no memory accounting for sendQ yet.
+ default:
+ panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
+ }
+ }
s.route.Release()
}
}
@@ -136,6 +181,17 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// payloadSize is the size of s.data.
+func (s *segment) payloadSize() int {
+ return s.data.Size()
+}
+
+// segMemSize is the amount of memory used to hold the segment data and
+// the associated metadata.
+func (s *segment) segMemSize() int {
+ return segSize + s.data.Size()
+}
+
// parse populates the sequence & ack numbers, flags, and window fields of the
// segment from the TCP header stored in the data. It then updates the view to
// skip the header.
@@ -146,12 +202,6 @@ func (s *segment) logicalLen() seqnum.Size {
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
func (s *segment) parse() bool {
- h, ok := s.data.PullUp(header.TCPMinimumSize)
- if !ok {
- return false
- }
- hdr := header.TCP(h)
-
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
// 1. That it's at least the minimum header size; if we don't do this
@@ -162,16 +212,12 @@ func (s *segment) parse() bool {
// N.B. The segment has already been validated as having at least the
// minimum TCP size before reaching here, so it's safe to read the
// fields.
- offset := int(hdr.DataOffset())
- if offset < header.TCPMinimumSize {
- return false
- }
- hdrWithOpts, ok := s.data.PullUp(offset)
- if !ok {
+ offset := int(s.hdr.DataOffset())
+ if offset < header.TCPMinimumSize || offset > len(s.hdr) {
return false
}
- s.options = []byte(hdrWithOpts[header.TCPMinimumSize:])
+ s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
// Query the link capabilities to decide if checksum validation is
@@ -180,22 +226,19 @@ func (s *segment) parse() bool {
if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
s.csumValid = true
verifyChecksum = false
- s.data.TrimFront(offset)
}
if verifyChecksum {
- hdr = header.TCP(hdrWithOpts)
- s.csum = hdr.Checksum()
- xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
- xsum = hdr.CalculateChecksum(xsum)
- s.data.TrimFront(offset)
+ s.csum = s.hdr.Checksum()
+ xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr)))
+ xsum = s.hdr.CalculateChecksum(xsum)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
}
- s.sequenceNumber = seqnum.Value(hdr.SequenceNumber())
- s.ackNumber = seqnum.Value(hdr.AckNumber())
- s.flags = hdr.Flags()
- s.window = seqnum.Size(hdr.WindowSize())
+ s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber())
+ s.ackNumber = seqnum.Value(s.hdr.AckNumber())
+ s.flags = s.hdr.Flags()
+ s.window = seqnum.Size(s.hdr.WindowSize())
return true
}
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 48a257137..54545a1b1 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -22,16 +22,16 @@ import (
//
// +stateify savable
type segmentQueue struct {
- mu sync.Mutex `state:"nosave"`
- list segmentList `state:"wait"`
- limit int
- used int
+ mu sync.Mutex `state:"nosave"`
+ list segmentList `state:"wait"`
+ ep *endpoint
+ frozen bool
}
// emptyLocked determines if the queue is empty.
// Preconditions: q.mu must be held.
func (q *segmentQueue) emptyLocked() bool {
- return q.used == 0
+ return q.list.Empty()
}
// empty determines if the queue is empty.
@@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool {
return r
}
-// setLimit updates the limit. No segments are immediately dropped in case the
-// queue becomes full due to the new limit.
-func (q *segmentQueue) setLimit(limit int) {
- q.mu.Lock()
- q.limit = limit
- q.mu.Unlock()
-}
-
// enqueue adds the given segment to the queue.
//
// Returns true when the segment is successfully added to the queue, in which
@@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) {
// false if the queue is full, in which case ownership is retained by the
// caller.
func (q *segmentQueue) enqueue(s *segment) bool {
+ // q.ep.receiveBufferParams() must be called without holding q.mu to
+ // avoid lock order inversion.
+ bufSz := q.ep.receiveBufferSize()
+ used := q.ep.receiveMemUsed()
q.mu.Lock()
- r := q.used < q.limit
- if r {
+ // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue
+ // is currently full).
+ allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen
+
+ if allow {
q.list.PushBack(s)
- q.used++
+ // Set the owner now that the endpoint owns the segment.
+ s.setOwner(q.ep, recvQ)
}
q.mu.Unlock()
- return r
+ return allow
}
// dequeue removes and returns the next segment from queue, if one exists.
@@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment {
s := q.list.Front()
if s != nil {
q.list.Remove(s)
- q.used--
}
q.mu.Unlock()
return s
}
+
+// freeze prevents any more segments from being added to the queue. i.e all
+// future segmentQueue.enqueue will return false and not add the segment to the
+// queue till the queue is unfroze with a corresponding segmentQueue.thaw call.
+func (q *segmentQueue) freeze() {
+ q.mu.Lock()
+ q.frozen = true
+ q.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen queue using segmentQueue.freeze() and
+// allows new segments to be queued again.
+func (q *segmentQueue) thaw() {
+ q.mu.Lock()
+ q.frozen = false
+ q.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
new file mode 100644
index 000000000..0ab7b8f56
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "unsafe"
+)
+
+const (
+ segSize = int(unsafe.Sizeof(segment{}))
+)
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 06dc9b7d7..6fa8d63cd 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -17,6 +17,7 @@ package tcp
import (
"fmt"
"math"
+ "sort"
"sync/atomic"
"time"
@@ -191,6 +192,10 @@ type sender struct {
// cc is the congestion control algorithm in use for this sender.
cc congestionControl
+
+ // rc has the fields needed for implementing RACK loss detection
+ // algorithm.
+ rc rackControl
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -259,6 +264,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
highRxt: iss,
rescueRxt: iss,
},
+ rc: rackControl{
+ fack: iss,
+ },
gso: ep.gso != nil,
}
@@ -618,6 +626,20 @@ func (s *sender) splitSeg(seg *segment, size int) {
nSeg.data.TrimFront(size)
nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
s.writeList.InsertAfter(seg, nSeg)
+
+ // The segment being split does not carry PUSH flag because it is
+ // followed by the newly split segment.
+ // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered
+ // segment (i.e., when there is no more queued data to be sent).
+ // Linux removes PSH flag only when the segment is being split over MSS
+ // and retains it when we are splitting the segment over lack of sender
+ // window space.
+ // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test()
+ if seg.data.Size() > s.maxPayloadSize {
+ seg.flags ^= header.TCPFlagPsh
+ }
+
seg.data.CapLength(size)
}
@@ -739,7 +761,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if !s.isAssignedSequenceNumber(seg) {
// Merge segments if allowed.
if seg.data.Size() != 0 {
- available := int(seg.sequenceNumber.Size(end))
+ available := int(s.sndNxt.Size(end))
if available > limit {
available = limit
}
@@ -782,8 +804,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// sent all at once.
return false
}
- if atomic.LoadUint32(&s.ep.cork) != 0 {
- // Hold back the segment until full.
+ // With TCP_CORK, hold back until minimum of the available
+ // send space and MSS.
+ // TODO(gvisor.dev/issue/2833): Drain the held segments after a
+ // timeout.
+ if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
return false
}
}
@@ -824,10 +849,52 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if available == 0 {
return false
}
+
+ // If the whole segment or at least 1MSS sized segment cannot
+ // be accomodated in the receiver advertized window, skip
+ // splitting and sending of the segment. ref:
+ // net/ipv4/tcp_output.c::tcp_snd_wnd_test()
+ //
+ // Linux checks this for all segment transmits not triggered by
+ // a probe timer. On this condition, it defers the segment split
+ // and transmit to a short probe timer.
+ //
+ // ref: include/net/tcp.h::tcp_check_probe_timer()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup()
+ //
+ // Instead of defining a new transmit timer, we attempt to split
+ // the segment right here if there are no pending segments. If
+ // there are pending segments, segment transmits are deferred to
+ // the retransmit timer handler.
+ if s.sndUna != s.sndNxt {
+ switch {
+ case available >= seg.data.Size():
+ // OK to send, the whole segments fits in the
+ // receiver's advertised window.
+ case available >= s.maxPayloadSize:
+ // OK to send, at least 1 MSS sized segment fits
+ // in the receiver's advertised window.
+ default:
+ return false
+ }
+ }
+
+ // The segment size limit is computed as a function of sender
+ // congestion window and MSS. When sender congestion window is >
+ // 1, this limit can be larger than MSS. Ensure that the
+ // currently available send space is not greater than minimum of
+ // this limit and MSS.
if available > limit {
available = limit
}
+ // If GSO is not in use then cap available to
+ // maxPayloadSize. When GSO is in use the gVisor GSO logic or
+ // the host GSO logic will cap the segment to the correct size.
+ if s.ep.gso == nil && available > s.maxPayloadSize {
+ available = s.maxPayloadSize
+ }
+
if seg.data.Size() > available {
s.splitSeg(seg, available)
}
@@ -1211,23 +1278,56 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
return true
}
+// Iterate the writeList and update RACK for each segment which is newly acked
+// either cumulatively or selectively. Loop through the segments which are
+// sacked, and update the RACK related variables and check for reordering.
+//
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+// steps 2 and 3.
+func (s *sender) walkSACK(rcvdSeg *segment) {
+ // Sort the SACK blocks. The first block is the most recent unacked
+ // block. The following blocks can be in arbitrary order.
+ sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks))
+ copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks)
+ sort.Slice(sackBlocks, func(i, j int) bool {
+ return sackBlocks[j].Start.LessThan(sackBlocks[i].Start)
+ })
+
+ seg := s.writeList.Front()
+ for _, sb := range sackBlocks {
+ // This check excludes DSACK blocks.
+ if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) {
+ continue
+ }
+
+ for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 {
+ if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked {
+ s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.detectReorder(seg)
+ seg.acked = true
+ }
+ seg = seg.Next()
+ }
+ }
+}
+
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
-func (s *sender) handleRcvdSegment(seg *segment) {
+func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Check if we can extract an RTT measurement from this ack.
- if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
s.updateRTO(time.Now().Sub(s.rttMeasureTime))
s.rttMeasureSeqNum = s.sndNxt
}
// Update Timestamp if required. See RFC7323, section-4.3.
- if s.ep.sendTSOk && seg.parsedOptions.TS {
- s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+ if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber)
}
// Insert SACKBlock information into our scoreboard.
if s.ep.sackPermitted {
- for _, sb := range seg.parsedOptions.SACKBlocks {
+ for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
// true:
// * SACK block acks data after the ack number in the
@@ -1240,27 +1340,42 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// NOTE: This check specifically excludes DSACK blocks
// which have start/end before sndUna and are used to
// indicate spurious retransmissions.
- if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
s.ep.scoreboard.Insert(sb)
- seg.hasNewSACKInfo = true
+ rcvdSeg.hasNewSACKInfo = true
}
}
+
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08
+ // section-7.2
+ // * Step 2: Update RACK stats.
+ // If the ACK is not ignored as invalid, update the RACK.rtt
+ // to be the RTT sample calculated using this ACK, and
+ // continue. If this ACK or SACK was for the most recently
+ // sent packet, then record the RACK.xmit_ts timestamp and
+ // RACK.end_seq sequence implied by this ACK.
+ // * Step 3: Detect packet reordering.
+ // If the ACK selectively or cumulatively acknowledges an
+ // unacknowledged and also never retransmitted sequence below
+ // RACK.fack, then the corresponding packet has been
+ // reordered and RACK.reord is set to TRUE.
+ s.walkSACK(rcvdSeg)
s.SetPipe()
}
// Count the duplicates and do the fast retransmit if needed.
- rtx := s.checkDuplicateAck(seg)
+ rtx := s.checkDuplicateAck(rcvdSeg)
// Stash away the current window size.
- s.sndWnd = seg.window
+ s.sndWnd = rcvdSeg.window
- ack := seg.ackNumber
+ ack := rcvdSeg.ackNumber
// Disable zero window probing if remote advertizes a non-zero receive
// window. This can be with an ACK to the zero window probe (where the
// acknumber refers to the already acknowledged byte) OR to any previously
// unacknowledged segment.
- if s.zeroWindowProbing && seg.window > 0 &&
+ if s.zeroWindowProbing && rcvdSeg.window > 0 &&
(ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) {
s.disableZeroWindowProbing()
}
@@ -1285,10 +1400,10 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// averaged RTT measurement only if the segment acknowledges
// some new data, i.e., only if it advances the left edge of
// the send window.
- if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
// TSVal/Ecr values sent by Netstack are at a millisecond
// granularity.
- elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
s.updateRTO(elapsed)
}
@@ -1321,9 +1436,15 @@ func (s *sender) handleRcvdSegment(seg *segment) {
s.writeNext = seg.Next()
}
+ // Update the RACK fields if SACK is enabled.
+ if s.ep.sackPermitted && !seg.acked {
+ s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.detectReorder(seg)
+ }
+
s.writeList.Remove(seg)
- // if SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then Only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
@@ -1376,7 +1497,7 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// that the window opened up, or the congestion window was inflated due
// to a duplicate ack during fast recovery. This will also re-enable
// the retransmit timer if needed.
- if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+ if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo {
s.sendData()
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 5fe23113b..b9993ce1a 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -50,7 +50,7 @@ func TestFastRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -90,14 +90,14 @@ func TestFastRecovery(t *testing.T) {
// Wait before checking metrics.
metricPollFn := func() error {
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want)
}
return nil
}
@@ -128,10 +128,10 @@ func TestFastRecovery(t *testing.T) {
// Wait before checking metrics.
metricPollFn = func() error {
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
}
return nil
}
@@ -215,7 +215,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
expected := tcp.InitialCwnd
@@ -257,7 +257,7 @@ func TestCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -362,7 +362,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -471,11 +471,11 @@ func TestRetransmit(t *testing.T) {
// MTU size though.
half := data[:len(data)/2]
if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
half = data[len(data)/2:]
if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -508,23 +508,23 @@ func TestRetransmit(t *testing.T) {
metricPollFn := func() error {
if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want)
}
return nil
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
new file mode 100644
index 000000000..d3f92b48c
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+)
+
+const (
+ maxPayload = 10
+ tsOptionSize = 12
+ maxTCPOptionSize = 40
+)
+
+// TestRACKUpdate tests the RACK related fields are updated when an ACK is
+// received on a SACK enabled connection.
+func TestRACKUpdate(t *testing.T) {
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ defer c.Cleanup()
+
+ var xmitTime time.Time
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.RACKState is what we expect.
+ if state.Sender.RACKState.XmitTime.Before(xmitTime) {
+ t.Fatalf("RACK transmit time failed to update when an ACK is received")
+ }
+
+ gotSeq := state.Sender.RACKState.EndSequence
+ wantSeq := state.Sender.SndNxt
+ if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
+ t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq)
+ }
+
+ if state.Sender.RACKState.RTT == 0 {
+ t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0")
+ }
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ data := buffer.NewView(maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data.
+ xmitTime = time.Now()
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
+ time.Sleep(200 * time.Millisecond)
+}
+
+// TestRACKDetectReorder tests that RACK detects packet reordering.
+func TestRACKDetectReorder(t *testing.T) {
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ defer c.Cleanup()
+
+ const ackNum = 2
+
+ var n int
+ ch := make(chan struct{})
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ gotSeq := state.Sender.RACKState.FACK
+ wantSeq := state.Sender.SndNxt
+ // FACK should be updated to the highest ending sequence number of the
+ // segment acknowledged most recently.
+ if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
+ t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq)
+ }
+
+ n++
+ if n < ackNum {
+ if state.Sender.RACKState.Reord {
+ t.Fatalf("RACK reorder detected when there is no reordering")
+ }
+ return
+ }
+
+ if state.Sender.RACKState.Reord == false {
+ t.Fatalf("RACK reorder detection failed")
+ }
+ close(ch)
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ data := buffer.NewView(ackNum * maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ for i := 0; i < ackNum; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+
+ start := c.IRS.Add(maxPayload + 1)
+ end := start.Add(maxPayload)
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-ch
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index ace79b7b2..ef7f5719f 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -46,8 +46,9 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
t.Helper()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err)
+ opt := tcpip.TCPSACKEnabled(enable)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -162,8 +163,9 @@ func TestSackPermittedAccept(t *testing.T) {
// Set the SynRcvd threshold to
// zero to force a syn cookie
// based accept to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
setStackSACKPermitted(t, c, sackEnabled)
@@ -236,8 +238,9 @@ func TestSackDisabledAccept(t *testing.T) {
// Set the SynRcvd threshold to
// zero to force a syn cookie
// based accept to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -400,7 +403,7 @@ func TestSACKRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -454,7 +457,7 @@ func TestSACKRecovery(t *testing.T) {
}
for _, s := range stats {
if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %v, want = %v", s.name, got, want)
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
}
}
return nil
@@ -529,19 +532,19 @@ func TestSACKRecovery(t *testing.T) {
// In SACK recovery only the first segment is fast retransmitted when
// entering recovery.
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want)
}
return nil
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6ef32a1b3..5f05608e2 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -57,7 +58,7 @@ func TestGiveUpConnect(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Register for notification, then start connection attempt.
@@ -66,7 +67,7 @@ func TestGiveUpConnect(t *testing.T) {
defer wq.EventUnregister(&waitEntry)
if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
// Close the connection, wait for completion.
@@ -74,22 +75,22 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
- if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
+ if err := ep.LastError(); err != tcpip.ErrAborted {
+ t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted)
}
// Call Connect again to retreive the handshake failure status
// and stats updates.
if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted)
}
if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
}
@@ -102,7 +103,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
}
@@ -115,10 +116,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want)
+ t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want)
}
}
@@ -129,20 +130,38 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
want := stats.TCP.FailedConnectionAttempts.Value() + 1
if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
- t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute)
}
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want)
+ t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want)
+ }
+}
+
+func TestCloseWithoutConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ c.EP.Close()
+
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -156,10 +175,10 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
- t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
- t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want)
+ t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want)
}
}
@@ -170,16 +189,16 @@ func TestTCPResetsSentIncrement(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
want := stats.TCP.SegmentsSent.Value() + 1
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send a SYN request.
@@ -213,7 +232,7 @@ func TestTCPResetsSentIncrement(t *testing.T) {
metricPollFn := func() error {
if got := stats.TCP.ResetsSent.Value(); got != want {
- return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want)
}
return nil
}
@@ -222,6 +241,38 @@ func TestTCPResetsSentIncrement(t *testing.T) {
}
}
+// TestTCPResetsSentNoICMP confirms that we don't get an ICMP
+// DstUnreachable packet when we try send a packet which is not part
+// of an active session.
+func TestTCPResetsSentNoICMP(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ stats := c.Stack().Stats()
+
+ // Send a SYN request for a closed port. This should elicit an RST
+ // but NOT an ICMPv4 DstUnreachable packet.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive whatever comes back.
+ b := c.GetPacket()
+ ipHdr := header.IPv4(b)
+ if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want {
+ t.Errorf("unexpected protocol, got = %d, want = %d", got, want)
+ }
+
+ // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
+ sent := stats.ICMP.V4PacketsSent
+ if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
+ t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
+ }
+}
+
// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
// a RST if an ACK is received on the listening socket for which there is no
// active handshake in progress and we are not using SYN cookies.
@@ -273,12 +324,12 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -291,16 +342,16 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// Lower stackwide TIME_WAIT timeout so that the reservations
// are released instantly on Close.
tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
- t.Fatalf("e.stack.SetTransportProtocolOption(%d, %v) = %v", tcp.ProtocolNumber, tcpTW, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err)
}
c.EP.Close()
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
SrcPort: context.TestPort,
@@ -330,8 +381,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
}
@@ -355,7 +406,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
})
if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
}
}
@@ -379,7 +430,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) {
})
if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
}
c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
}
@@ -403,7 +454,7 @@ func TestNonBlockingClose(t *testing.T) {
t0 := time.Now()
ep.Close()
if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
+ t.Fatalf("Took too long to close: %s", diff)
}
}
@@ -414,8 +465,9 @@ func TestConnectResetAfterClose(t *testing.T) {
// Set TCPLinger to 3 seconds so that sockets are marked closed
// after 3 second in FIN_WAIT2 state.
tcpLingerTimeout := 3 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err)
+ opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -428,8 +480,8 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -470,8 +522,8 @@ func TestConnectResetAfterClose(t *testing.T) {
// RST is always generated with sndNxt which if the FIN
// has been sent will be 1 higher than the sequence number
// of the FIN itself.
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -488,8 +540,9 @@ func TestCurrentConnectedIncrement(t *testing.T) {
// Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
// after 1 second in TIME_WAIT state.
tcpTimeWaitTimeout := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -497,11 +550,11 @@ func TestCurrentConnectedIncrement(t *testing.T) {
c.EP = nil
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got)
}
gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
if gotConnected != 1 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 1", gotConnected)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected)
}
ep.Close()
@@ -509,8 +562,8 @@ func TestCurrentConnectedIncrement(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -524,10 +577,10 @@ func TestCurrentConnectedIncrement(t *testing.T) {
})
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = %v", got, gotConnected)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected)
}
// Ack and send FIN as well.
@@ -545,8 +598,8 @@ func TestCurrentConnectedIncrement(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -556,10 +609,10 @@ func TestCurrentConnectedIncrement(t *testing.T) {
time.Sleep(1200 * time.Millisecond)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -575,7 +628,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
c.EP = nil
if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
}
// Send a FIN for ESTABLISHED --> CLOSED-WAIT
@@ -592,8 +645,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -603,7 +656,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
time.Sleep(10 * time.Millisecond)
if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
}
// Close the application endpoint for CLOSE_WAIT --> LAST_ACK
@@ -613,14 +666,14 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Pause the endpoint`s protocolMainLoop.
@@ -657,15 +710,15 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
// Expect the endpoint to be closed.
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
// Check if the endpoint was moved to CLOSED and netstack a reset in
@@ -673,8 +726,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -691,7 +744,7 @@ func TestSimpleReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -714,7 +767,7 @@ func TestSimpleReceive(t *testing.T) {
// Receive data.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -725,135 +778,234 @@ func TestSimpleReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
}
-// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when
-// creating a new active IPv4 TCP socket. It should be present in the sent TCP
+// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when
+// creating a new active TCP socket. It should be present in the sent TCP
// SYN segment.
-func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
+func TestUserSuppliedMSSOnConnect(t *testing.T) {
const mtu = 5000
- const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
- tests := []struct {
- name string
- setMSS int
- expMSS uint16
+
+ ips := []struct {
+ name string
+ createEP func(*context.Context)
+ connectAddr tcpip.Address
+ checker func(*testing.T, *context.Context, uint16, int)
+ maxMSS uint16
}{
{
- "EqualToMaxMSS",
- maxMSS,
- maxMSS,
- },
- {
- "LessThanMTU",
- maxMSS - 1,
- maxMSS - 1,
+ name: "IPv4",
+ createEP: func(c *context.Context) {
+ c.Create(-1)
+ },
+ connectAddr: context.TestAddr,
+ checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
+ },
+ maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
},
{
- "GreaterThanMTU",
- maxMSS + 1,
- maxMSS,
+ name: "IPv6",
+ createEP: func(c *context.Context) {
+ c.CreateV6Endpoint(true)
+ },
+ connectAddr: context.TestV6Addr,
+ checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
+ },
+ maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
+ for _, ip := range ips {
+ t.Run(ip.name, func(t *testing.T) {
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ name: "EqualToMaxMSS",
+ setMSS: ip.maxMSS,
+ expMSS: ip.maxMSS,
+ },
+ {
+ name: "LessThanMaxMSS",
+ setMSS: ip.maxMSS - 1,
+ expMSS: ip.maxMSS - 1,
+ },
+ {
+ name: "GreaterThanMaxMSS",
+ setMSS: ip.maxMSS + 1,
+ expMSS: ip.maxMSS,
+ },
+ }
- c.Create(-1)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
- // Set the MSS socket option.
- if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil {
- t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
- }
+ ip.createEP(c)
- // Get expected window size.
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
- }
- ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+ // Set the MSS socket option.
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
+ }
- // Start connection attempt to IPv4 address.
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
- }
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
- // Receive SYN packet with our user supplied MSS.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
+ if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Connect(%+v): %s", connectAddr, err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ ip.checker(t, c, test.expMSS, ws)
+ })
+ }
})
}
}
-// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when
-// creating a new active IPv6 TCP socket. It should be present in the sent TCP
-// SYN segment.
-func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
- const mtu = 5000
- const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize
- tests := []struct {
- name string
- setMSS uint16
- expMSS uint16
+// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used
+// when completing the handshake for a new TCP connection from a TCP
+// listening socket. It should be present in the sent TCP SYN-ACK segment.
+func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
+ const (
+ nonSynCookieAccepts = 2
+ totalAccepts = 4
+ mtu = 5000
+ )
+
+ ips := []struct {
+ name string
+ createEP func(*context.Context)
+ sendPkt func(*context.Context, *context.Headers)
+ checker func(*testing.T, *context.Context, uint16, uint16)
+ maxMSS uint16
}{
{
- "EqualToMaxMSS",
- maxMSS,
- maxMSS,
- },
- {
- "LessThanMTU",
- maxMSS - 1,
- maxMSS - 1,
+ name: "IPv4",
+ createEP: func(c *context.Context) {
+ c.Create(-1)
+ },
+ sendPkt: func(c *context.Context, h *context.Headers) {
+ c.SendPacket(nil, h)
+ },
+ checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
+ },
+ maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
},
{
- "GreaterThanMTU",
- maxMSS + 1,
- maxMSS,
+ name: "IPv6",
+ createEP: func(c *context.Context) {
+ c.CreateV6Endpoint(false)
+ },
+ sendPkt: func(c *context.Context, h *context.Headers) {
+ c.SendV6Packet(nil, h)
+ },
+ checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
+ },
+ maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
+ for _, ip := range ips {
+ t.Run(ip.name, func(t *testing.T) {
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ name: "EqualToMaxMSS",
+ setMSS: ip.maxMSS,
+ expMSS: ip.maxMSS,
+ },
+ {
+ name: "LessThanMaxMSS",
+ setMSS: ip.maxMSS - 1,
+ expMSS: ip.maxMSS - 1,
+ },
+ {
+ name: "GreaterThanMaxMSS",
+ setMSS: ip.maxMSS + 1,
+ expMSS: ip.maxMSS,
+ },
+ }
- c.CreateV6Endpoint(true)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
- // Set the MSS socket option.
- if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
- t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
- }
+ ip.createEP(c)
- // Get expected window size.
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
- }
- ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+ // Set the SynRcvd threshold to force a syn cookie based accept to happen.
+ opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
- // Start connection attempt to IPv6 address.
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
- }
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
+ }
- // Receive SYN packet with our user supplied MSS.
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ bindAddr := tcpip.FullAddress{Port: context.StackPort}
+ if err := c.EP.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s:", bindAddr, err)
+ }
+
+ if err := c.EP.Listen(totalAccepts); err != nil {
+ t.Fatalf("Listen(%d): %s:", totalAccepts, err)
+ }
+
+ // The first nonSynCookieAccepts packets sent will trigger a gorooutine
+ // based accept. The rest will trigger a cookie based accept.
+ for i := 0; i < totalAccepts; i++ {
+ // Send a SYN requests.
+ iss := seqnum.Value(i)
+ srcPort := context.TestPort + uint16(i)
+ ip.sendPkt(c, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ ip.checker(t, c, srcPort, test.expMSS)
+ }
+ })
+ }
})
}
}
-
func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -879,7 +1031,7 @@ func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
@@ -907,7 +1059,7 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
@@ -944,8 +1096,8 @@ func TestTCPAckBeforeAcceptV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
@@ -982,8 +1134,8 @@ func TestTCPAckBeforeAcceptV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestSendRstOnListenerRxAckV4(t *testing.T) {
@@ -1011,7 +1163,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
func TestSendRstOnListenerRxAckV6(t *testing.T) {
@@ -1039,7 +1191,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
// TestListenShutdown tests for the listening endpoint replying with RST
@@ -1155,8 +1307,8 @@ func TestTOSV4(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790), // Acknum is initial sequence number + 1
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790), // Acknum is initial sequence number + 1
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
checker.TOS(tos, 0),
@@ -1204,8 +1356,8 @@ func TestTrafficClassV6(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
checker.TOS(tos, 0),
@@ -1232,14 +1384,16 @@ func TestConnectBindToDevice(t *testing.T) {
c.Create(-1)
bindToDevice := tcpip.BindToDeviceOption(test.device)
- c.EP.SetSockOpt(bindToDevice)
+ if err := c.EP.SetSockOpt(&bindToDevice); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err)
+ }
// Start connection attempt.
waitEntry, _ := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -1251,7 +1405,7 @@ func TestConnectBindToDevice(t *testing.T) {
),
)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
}
tcpHdr := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
@@ -1270,74 +1424,97 @@ func TestConnectBindToDevice(t *testing.T) {
c.GetPacket()
if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
}
})
}
}
-func TestRstOnSynSent(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
+func TestSynSent(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ reset bool
+ }{
+ {"RstOnSynSent", true},
+ {"CloseOnSynSent", false},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
- // Create an endpoint, don't handshake because we want to interfere with the
- // handshake process.
- c.Create(-1)
+ // Create an endpoint, don't handshake because we want to interfere with the
+ // handshake process.
+ c.Create(-1)
- // Start connection attempt.
- waitEntry, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventOut)
- defer c.WQ.EventUnregister(&waitEntry)
+ // Start connection attempt.
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
- addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted)
- }
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
+ }
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
- // Ensure that we've reached SynSent state
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
- // Send a packet with a proper ACK and a RST flag to cause the socket
- // to Error and close out
- iss := seqnum.Value(789)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
+ if test.reset {
+ // Send a packet with a proper ACK and a RST flag to cause the socket
+ // to error and close out.
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+ } else {
+ c.EP.Close()
+ }
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatal("timed out waiting for packet to arrive")
- }
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for packet to arrive")
+ }
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused)
- }
+ if test.reset {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
+ }
+ } else {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted)
+ }
+ }
- // Due to the RST the endpoint should be in an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+
+ // Due to the RST the endpoint should be in an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ })
}
}
@@ -1352,7 +1529,7 @@ func TestOutOfOrderReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send second half of data first, with seqnum 3 ahead of expected.
@@ -1370,8 +1547,8 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1379,7 +1556,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send the first 3 bytes now.
@@ -1406,7 +1583,7 @@ func TestOutOfOrderReceive(t *testing.T) {
}
continue
}
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
read = append(read, v...)
@@ -1421,8 +1598,8 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1432,11 +1609,11 @@ func TestOutOfOrderFlood(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Create a new connection with initial window size of 10.
- c.CreateConnected(789, 30000, 10)
+ rcvBufSz := math.MaxUint16
+ c.CreateConnected(789, 30000, rcvBufSz)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send 100 packets before the actual one that is expected.
@@ -1454,8 +1631,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1475,8 +1652,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1495,8 +1672,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(793),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(793),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1513,7 +1690,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -1537,8 +1714,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1552,11 +1729,11 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
// We shouldn't consume a sequence number on RST.
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// This final ACK should be ignored because an ACK on a reset doesn't mean
@@ -1582,7 +1759,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -1606,8 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1620,11 +1797,11 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
))
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Cause a RST to be generated by closing the read end now since we have
@@ -1639,11 +1816,11 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
// RST is always generated with sndNxt which if the FIN
// has been sent will be 1 higher than the sequence
// number of the FIN itself.
- checker.SeqNum(uint32(c.IRS)+2),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// The ACK to the FIN should now be rejected since the connection has been
@@ -1665,19 +1842,19 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
- t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want)
+ t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
}
}
@@ -1685,7 +1862,8 @@ func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, 10)
+ const rcvBufSz = 10
+ c.CreateConnected(789, 30000, rcvBufSz)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1693,11 +1871,16 @@ func TestFullWindowReceive(t *testing.T) {
_, _, err := c.EP.Read(nil)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
- // Fill up the window.
- data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+ // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
+ // the provided buffer value by tcp.SegOverheadFactor to calculate the actual
+ // receive buffer size.
+ data := make([]byte, tcp.SegOverheadFactor*rcvBufSz)
+ for i := range data {
+ data[i] = byte(i % 255)
+ }
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
@@ -1718,17 +1901,17 @@ func TestFullWindowReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(0),
+ checker.TCPWindow(0),
),
)
// Receive data and check it.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -1737,17 +1920,17 @@ func TestFullWindowReceive(t *testing.T) {
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
- t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want)
+ t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want)
}
// Check that we get an ACK for the newly non-zero window.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(10),
+ checker.TCPWindow(10),
),
)
}
@@ -1756,28 +1939,32 @@ func TestNoWindowShrinking(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Start off with a window size of 10, then shrink it to 5.
- c.CreateConnected(789, 30000, 10)
-
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err)
- }
+ // Start off with a certain receive buffer then cut it in half and verify that
+ // the right edge of the window does not shrink.
+ // NOTE: Netstack doubles the value specified here.
+ rcvBufSize := 65536
+ iss := seqnum.Value(789)
+ // Enable window scaling with a scale of zero from our end.
+ c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{
+ header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
+ })
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
-
- // Send 3 bytes, check that the peer acknowledges them.
- data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
- c.SendPacket(data[:3], &context.Headers{
+ // Send a 1 byte payload so that we can record the current receive window.
+ // Send a payload of half the size of rcvBufSize.
+ seqNum := iss.Add(1)
+ payload := []byte{1}
+ c.SendPacket(payload, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: seqNum,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1789,50 +1976,97 @@ func TestNoWindowShrinking(t *testing.T) {
t.Fatalf("Timed out waiting for data to arrive")
}
- // Check that data is acknowledged, and that window doesn't go to zero
- // just yet because it was previously set to 10. It must go to 7 now.
- checker.IPv4(t, c.GetPacket(),
+ // Read the 1 byte payload we just sent.
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ if got, want := payload, v; !bytes.Equal(got, want) {
+ t.Fatalf("got data: %v, want: %v", got, want)
+ }
+
+ seqNum = seqNum.Add(1)
+ // Verify that the ACK does not shrink the window.
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(793),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(7),
),
)
+ // Stash the initial window.
+ initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
+ initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd))
+ // Now shrink the receive buffer to half its original size.
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
+ }
- // Send 7 more bytes, check that the window fills up.
- c.SendPacket(data[3:], &context.Headers{
+ data := generateRandomPayload(t, rcvBufSize)
+ // Send a payload of half the size of rcvBufSize.
+ c.SendPacket(data[:rcvBufSize/2], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 793,
+ SeqNum: seqNum,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
+ seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
+ // Verify that the ACK does not shrink the window.
+ pkt = c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
+ newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd))
+ if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) {
+ t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq)
}
+ // Send another payload of half the size of rcvBufSize. This should fill up the
+ // socket receive buffer and we should see a zero window.
+ c.SendPacket(data[rcvBufSize/2:], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqNum,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
+
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(0),
+ checker.TCPWindow(0),
),
)
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
// Receive data and check it.
- read := make([]byte, 0, 10)
+ read := make([]byte, 0, rcvBufSize)
for len(read) < len(data) {
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
read = append(read, v...)
@@ -1842,15 +2076,15 @@ func TestNoWindowShrinking(t *testing.T) {
t.Fatalf("got data = %v, want = %v", read, data)
}
- // Check that we get an ACK for the newly non-zero window, which is the
- // new size.
+ // Check that we get an ACK for the newly non-zero window, which is the new
+ // receive buffer size we set after the connection was established.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(5),
+ checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale),
),
)
}
@@ -1866,7 +2100,7 @@ func TestSimpleSend(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received.
@@ -1875,8 +2109,8 @@ func TestSimpleSend(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -1908,7 +2142,7 @@ func TestZeroWindowSend(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check if we got a zero-window probe.
@@ -1917,8 +2151,8 @@ func TestZeroWindowSend(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -1939,8 +2173,8 @@ func TestZeroWindowSend(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -1976,19 +2210,19 @@ func TestScaledWindowConnect(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
- // Check that data is received, and that advertised window is 0xbfff,
+ // Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xbfff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2008,7 +2242,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -2018,9 +2252,9 @@ func TestNonScaledWindowConnect(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xffff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2036,39 +2270,40 @@ func TestScaledWindowAccept(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
- c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
+ c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2081,19 +2316,19 @@ func TestScaledWindowAccept(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
- // Check that data is received, and that advertised window is 0xbfff,
+ // Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xbfff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2109,21 +2344,21 @@ func TestNonScaledWindowAccept(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
@@ -2135,14 +2370,14 @@ func TestNonScaledWindowAccept(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2155,7 +2390,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -2165,9 +2400,9 @@ func TestNonScaledWindowAccept(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xffff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2180,18 +2415,19 @@ func TestZeroScaledWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Set the window size such that a window scale of 4 will be used.
- const wnd = 65535 * 10
- const ws = uint32(4)
- c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
+ // Set the buffer size such that a window scale of 5 will be used.
+ const bufSz = 65535 * 10
+ const ws = uint32(5)
+ c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
// Write chunks of 50000 bytes.
- remain := wnd
+ remain := 0
sent := 0
data := make([]byte, 50000)
- for remain > len(data) {
+ // Keep writing till the window drops below len(data).
+ for {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
@@ -2201,21 +2437,25 @@ func TestZeroScaledWindowReceive(t *testing.T) {
RcvWnd: 30000,
})
sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(remain>>ws)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Don't reduce window to zero here.
+ if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) {
+ remain = wnd << ws
+ break
+ }
}
// Make the window non-zero, but the scaled window zero.
- if remain >= 16 {
+ for remain >= 16 {
data = data[:remain-15]
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
@@ -2226,25 +2466,38 @@ func TestZeroScaledWindowReceive(t *testing.T) {
RcvWnd: 30000,
})
sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(0),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Since the receive buffer is split between window advertisement and
+ // application data buffer the window does not always reflect the space
+ // available and actual space available can be a bit more than what is
+ // advertised in the window.
+ wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize())
+ if wnd == 0 {
+ break
+ }
+ remain = wnd << ws
}
- // Read at least 1MSS of data. An ack should be sent in response to that.
+ // Read at least 2MSS of data. An ack should be sent in response to that.
+ // Since buffer space is now split in half between window and application
+ // data we need to read more than 1 MSS(65536) of data for a non-zero window
+ // update to be sent. For 1MSS worth of window to be available we need to
+ // read at least 128KB. Since our segments above were 50KB each it means
+ // we need to read at 3 packets.
sz := 0
- for sz < defaultMTU {
+ for sz < defaultMTU*2 {
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
sz += len(v)
}
@@ -2253,9 +2506,9 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(sz>>ws)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -2311,7 +2564,7 @@ func TestSegmentMerging(t *testing.T) {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2322,8 +2575,8 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize+1),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+uint32(i)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2345,8 +2598,8 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(len(allData)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+11),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+11),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2381,7 +2634,7 @@ func TestDelay(t *testing.T) {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2393,8 +2646,8 @@ func TestDelay(t *testing.T) {
checker.PayloadLen(len(want)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2428,7 +2681,7 @@ func TestUndelay(t *testing.T) {
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2440,8 +2693,8 @@ func TestUndelay(t *testing.T) {
checker.PayloadLen(len(allData[0])+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2463,8 +2716,8 @@ func TestUndelay(t *testing.T) {
checker.PayloadLen(len(allData[1])+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2512,7 +2765,7 @@ func TestMSSNotDelayed(t *testing.T) {
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2525,8 +2778,8 @@ func TestMSSNotDelayed(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2563,7 +2816,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received in chunks.
@@ -2577,8 +2830,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2631,7 +2884,7 @@ func TestSetTTL(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
@@ -2639,7 +2892,7 @@ func TestSetTTL(t *testing.T) {
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %s", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -2671,7 +2924,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
@@ -2683,11 +2936,11 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -2698,14 +2951,14 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2725,8 +2978,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ opt := tcpip.TCPSynRcvdCountThresholdOption(0)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
// Create EP and start listening.
@@ -2753,12 +3007,12 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2794,7 +3048,7 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) {
select {
case err := <-ch:
if err != nil {
- t.Fatalf("Error creating endpoint: %v", err)
+ t.Fatalf("Error creating endpoint: %s", err)
}
case <-time.After(2 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -2813,13 +3067,13 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- const wndScale = 2
+ const wndScale = 3
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
}
@@ -2830,7 +3084,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
// Receive SYN packet.
@@ -2854,7 +3108,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagSyn),
checker.SrcPort(tcpHdr.SourcePort()),
- checker.SeqNum(tcpHdr.SequenceNumber()),
+ checker.TCPSeqNum(tcpHdr.SequenceNumber()),
checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
),
)
@@ -2875,16 +3129,16 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
),
)
// Wait for connection to be established.
select {
case <-ch:
- if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ if err := c.EP.LastError(); err != nil {
+ t.Fatalf("Connect failed: %s", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -2899,22 +3153,22 @@ func TestCloseListener(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Close the listener and measure how long it takes.
t0 := time.Now()
ep.Close()
if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
+ t.Fatalf("Took too long to close: %s", diff)
}
}
@@ -2950,22 +3204,25 @@ loop:
case tcpip.ErrConnectionReset:
break loop
default:
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
// Expect the state to be StateError and subsequent Reads to fail with HardError.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -2990,7 +3247,7 @@ func TestSendOnResetConnection(t *testing.T) {
// Try to write.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
@@ -3001,8 +3258,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
defer c.Cleanup()
const numRetries = 2
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil {
- t.Fatalf("could not set protocol option MaxRetries.\n")
+ opt := tcpip.TCPMaxRetriesOption(numRetries)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
@@ -3013,7 +3271,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Expect first transmit and MaxRetries retransmits.
@@ -3048,7 +3306,10 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
)
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -3058,15 +3319,16 @@ func TestMaxRTO(t *testing.T) {
defer c.Cleanup()
rto := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err)
+ opt := tcpip.TCPMaxRTOOption(rto)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
_, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.TCP(
@@ -3089,6 +3351,63 @@ func TestMaxRTO(t *testing.T) {
}
}
+// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
+// unique on retransmits.
+func TestRetransmitIPv4IDUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ size int
+ }{
+ {"1Byte", 1},
+ {"512Bytes", 512},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ // Disabling PMTU discovery causes all packets sent from this socket to
+ // have DF=0. This needs to be done because the IPv4 ID uniqueness
+ // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
+ // Section 4, and datagrams with DF=0 are non-atomic.
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
+ t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}}
+ // Expect two retransmitted packets, and that all packets received have
+ // unique IPv4 ID values.
+ for i := 0; i <= 2; i++ {
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ id := header.IPv4(pkt).ID()
+ if _, exists := idSet[id]; exists {
+ t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
+ }
+ idSet[id] = struct{}{}
+ }
+ })
+ }
+}
+
func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -3097,15 +3416,15 @@ func TestFinImmediately(t *testing.T) {
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3125,8 +3444,8 @@ func TestFinImmediately(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3140,15 +3459,15 @@ func TestFinRetransmit(t *testing.T) {
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3158,8 +3477,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3179,8 +3498,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3195,7 +3514,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3203,8 +3522,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3221,15 +3540,15 @@ func TestFinWithNoPendingData(t *testing.T) {
// Shutdown, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3250,8 +3569,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3268,7 +3587,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
view := buffer.NewView(10)
for i := tcp.InitialCwnd; i > 0; i-- {
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
}
@@ -3278,8 +3597,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3290,15 +3609,15 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
// because the congestion window doesn't allow it. Wait until a
// retransmit is received.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3317,8 +3636,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3338,8 +3657,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3354,7 +3673,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3362,8 +3681,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3380,15 +3699,15 @@ func TestFinWithPendingData(t *testing.T) {
// Write new data, but don't acknowledge it.
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3396,15 +3715,15 @@ func TestFinWithPendingData(t *testing.T) {
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3424,8 +3743,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3441,7 +3760,7 @@ func TestFinWithPartialAck(t *testing.T) {
// FIN from the test side.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3449,8 +3768,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3470,23 +3789,23 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
// Write new data, but don't acknowledge it.
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3494,15 +3813,15 @@ func TestFinWithPartialAck(t *testing.T) {
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3540,20 +3859,20 @@ func TestUpdateListenBacklog(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Update the backlog with another Listen() on the same endpoint.
if err := ep.Listen(20); err != nil {
- t.Fatalf("Listen failed to update backlog: %v", err)
+ t.Fatalf("Listen failed to update backlog: %s", err)
}
ep.Close()
@@ -3585,7 +3904,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
// Send some data. Check that it's capped by the window size.
view := buffer.NewView(65535)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that only data that fits in the scaled window is sent.
@@ -3593,8 +3912,8 @@ func scaledSendWindow(t *testing.T, scale uint8) {
checker.PayloadLen((1<<scale)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3631,18 +3950,18 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
})
if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
- t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want)
+ t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want)
}
// Ensure there were no errors during handshake. If these stats have
// incremented, then the connection should not have been established.
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0)
+ t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0)
}
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0)
+ t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0)
}
}
@@ -3666,10 +3985,10 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c.SendSegment(vv)
if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
}
}
@@ -3732,7 +4051,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3759,8 +4078,9 @@ func TestReadAfterClosedState(t *testing.T) {
// Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
// after 1 second in TIME_WAIT state.
tcpTimeWaitTimeout := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -3770,7 +4090,7 @@ func TestReadAfterClosedState(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Shutdown immediately for write, check that we get a FIN.
@@ -3782,14 +4102,14 @@ func TestReadAfterClosedState(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Send some data and acknowledge the FIN.
@@ -3807,8 +4127,8 @@ func TestReadAfterClosedState(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(uint32(791+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(uint32(791+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3818,7 +4138,7 @@ func TestReadAfterClosedState(t *testing.T) {
time.Sleep(tcpTimeWaitTimeout * 2)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Wait for receive to be notified.
@@ -3853,11 +4173,11 @@ func TestReadAfterClosedState(t *testing.T) {
// Now that we drained the queue, check that functions fail with the
// right error code.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -3871,66 +4191,84 @@ func TestReusePort(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
// Second case, an endpoint that was bound and is connecting..
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
// Third case, an endpoint that was bound and is listening.
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
}
@@ -3939,11 +4277,11 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
if int(s) != v {
- t.Fatalf("got receive buffer size = %v, want = %v", s, v)
+ t.Fatalf("got receive buffer size = %d, want = %d", s, v)
}
}
@@ -3952,24 +4290,24 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
if int(s) != v {
- t.Fatalf("got send buffer size = %v, want = %v", s, v)
+ t.Fatalf("got send buffer size = %d, want = %d", s, v)
}
}
func TestDefaultBufferSizes(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer func() {
if ep != nil {
@@ -3981,28 +4319,42 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ {
+ opt := tcpip.TCPSendBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultSendBufferSize * 2,
+ Max: tcp.DefaultSendBufferSize * 20,
+ }
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
ep.Close()
ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize * 3,
+ Max: tcp.DefaultReceiveBufferSize * 30,
+ }
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
ep.Close()
ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
@@ -4011,34 +4363,40 @@ func TestDefaultBufferSizes(t *testing.T) {
func TestMinMaxBufferSizes(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
// Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ {
+ opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
- // Set values below the min.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
+ // Set values below the min/2.
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil {
t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err)
}
@@ -4049,28 +4407,30 @@ func TestMinMaxBufferSizes(t *testing.T) {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
+ // Values above max are capped at max and then doubled.
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err)
}
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
+ // Values above max are capped at max and then doubled.
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
}
func TestBindToDeviceOption(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}})
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
if err := s.CreateNIC(321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
+ t.Errorf("CreateNIC failed: %s", err)
}
// nicIDPtr is used instead of taking the address of NICID literals, which is
@@ -4094,16 +4454,15 @@ func TestBindToDeviceOption(t *testing.T) {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
+ if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
bindToDevice := tcpip.BindToDeviceOption(88888)
if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt got %v, want %v", err, nil)
- }
- if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %d, want %d", got, want)
+ t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
+ } else if bindToDevice != testAction.getBindToDevice {
+ t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice)
}
})
}
@@ -4111,11 +4470,11 @@ func TestBindToDeviceOption(t *testing.T) {
func makeStack() (*stack.Stack, *tcpip.Error) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{
- ipv4.NewProtocol(),
- ipv6.NewProtocol(),
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
id := loopback.New()
@@ -4166,12 +4525,12 @@ func TestSelfConnect(t *testing.T) {
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Register for notification, then start connection attempt.
@@ -4180,12 +4539,12 @@ func TestSelfConnect(t *testing.T) {
defer wq.EventUnregister(&waitEntry)
if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
<-notifyCh
- if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("Connect failed: %v", err)
+ if err := ep.LastError(); err != nil {
+ t.Fatalf("Connect failed: %s", err)
}
// Write something.
@@ -4193,7 +4552,7 @@ func TestSelfConnect(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Read back what was written.
@@ -4202,12 +4561,12 @@ func TestSelfConnect(t *testing.T) {
rd, _, err := ep.Read(nil)
if err != nil {
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
<-notifyCh
rd, _, err = ep.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
}
@@ -4291,7 +4650,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
eps = append(eps, ep)
switch network {
@@ -4342,7 +4701,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
- t.Fatalf("Bind(%d) failed: %v", i, err)
+ t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
want := tcpip.ErrConnectStarted
@@ -4350,7 +4709,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
want = tcpip.ErrNoPortAvailable
}
if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
- t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want)
+ t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
}
})
}
@@ -4384,7 +4743,7 @@ func TestPathMTUDiscovery(t *testing.T) {
}
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
@@ -4398,8 +4757,8 @@ func TestPathMTUDiscovery(t *testing.T) {
checker.PayloadLen(size+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(seqNum),
- checker.AckNum(790),
+ checker.TCPSeqNum(seqNum),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -4487,11 +4846,11 @@ func TestStackSetCongestionControl(t *testing.T) {
var oldCC tcpip.CongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
- t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err {
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err)
}
var cc tcpip.CongestionControlOption
@@ -4523,12 +4882,12 @@ func TestStackAvailableCongestionControl(t *testing.T) {
s := c.Stack()
// Query permitted congestion control algorithms.
- var aCC tcpip.AvailableCongestionControlOption
+ var aCC tcpip.TCPAvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
}
- if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want)
}
}
@@ -4539,18 +4898,18 @@ func TestStackSetAvailableCongestionControl(t *testing.T) {
s := c.Stack()
// Setting AvailableCongestionControlOption should fail.
- aCC := tcpip.AvailableCongestionControlOption("xyz")
+ aCC := tcpip.TCPAvailableCongestionControlOption("xyz")
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC)
}
// Verify that we still get the expected list of congestion control options.
- var cc tcpip.AvailableCongestionControlOption
+ var cc tcpip.TCPAvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
+ t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err)
}
- if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want)
}
}
@@ -4574,25 +4933,25 @@ func TestEndpointSetCongestionControl(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
var oldCC tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&oldCC); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err)
}
if connected {
c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
}
- if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
- t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
+ if err := c.EP.SetSockOpt(&tc.cc); err != tc.err {
+ t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err)
}
var cc tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&cc); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
+ t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err)
}
got, want := cc, oldCC
@@ -4604,7 +4963,7 @@ func TestEndpointSetCongestionControl(t *testing.T) {
want = tc.cc
}
if got != want {
- t.Fatalf("got congestion control: %v, want: %v", got, want)
+ t.Fatalf("got congestion control = %+v, want = %+v", got, want)
}
})
}
@@ -4614,8 +4973,8 @@ func TestEndpointSetCongestionControl(t *testing.T) {
func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
opt := tcpip.CongestionControlOption("cubic")
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -4625,11 +4984,23 @@ func TestKeepalive(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ const keepAliveIdle = 100 * time.Millisecond
const keepAliveInterval = 3 * time.Second
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+ keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle)
+ if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err)
+ }
+ keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval)
+ if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err)
+ }
c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5)
- c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+ if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil {
+ t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil {
+ t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err)
+ }
// 5 unacked keepalives are sent. ACK each one, and check that the
// connection stays alive after 5.
@@ -4638,8 +5009,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4657,14 +5028,14 @@ func TestKeepalive(t *testing.T) {
// Check that the connection is still alive.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
view := buffer.NewView(3)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -4672,8 +5043,8 @@ func TestKeepalive(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -4684,8 +5055,8 @@ func TestKeepalive(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
),
)
@@ -4710,8 +5081,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next-1)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(next-1)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4737,26 +5108,30 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(next)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ t.Helper()
// Send a SYN request.
irs = seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
@@ -4775,7 +5150,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
checker.SrcPort(context.StackPort),
checker.DstPort(srcPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
if synCookieInUse {
@@ -4801,6 +5176,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
}
func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ t.Helper()
// Send a SYN request.
irs = seqnum.Value(789)
c.SendV6Packet(nil, &context.Headers{
@@ -4819,7 +5195,7 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo
checker.SrcPort(context.StackPort),
checker.DstPort(srcPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
if synCookieInUse {
@@ -4854,23 +5230,24 @@ func TestListenBacklogFull(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
// Start listening.
- listenBacklog := 2
+ listenBacklog := 10
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
- for i := 0; i < listenBacklog; i++ {
- executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */)
+ lastPortOffset := uint16(0)
+ for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ {
+ executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
}
time.Sleep(50 * time.Millisecond)
@@ -4878,7 +5255,7 @@ func TestListenBacklogFull(t *testing.T) {
// Now execute send one more SYN. The stack should not respond as the backlog
// is full at this point.
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 2,
+ SrcPort: context.TestPort + uint16(lastPortOffset),
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: seqnum.Value(789),
@@ -4892,14 +5269,14 @@ func TestListenBacklogFull(t *testing.T) {
defer c.WQ.EventUnregister(&we)
for i := 0; i < listenBacklog; i++ {
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4909,7 +5286,7 @@ func TestListenBacklogFull(t *testing.T) {
}
// Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != tcpip.ErrWouldBlock {
select {
case <-ch:
@@ -4919,16 +5296,16 @@ func TestListenBacklogFull(t *testing.T) {
}
// Now a new handshake must succeed.
- executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */)
+ executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4942,7 +5319,7 @@ func TestListenBacklogFull(t *testing.T) {
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
}
@@ -4951,6 +5328,8 @@ func TestListenBacklogFull(t *testing.T) {
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+ subnet := context.StackAddrWithPrefix.Subnet()
+ subnetBroadcastAddr := subnet.Broadcast()
tests := []struct {
name string
@@ -4958,53 +5337,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
dstAddr tcpip.Address
}{
{
- "SourceUnspecified",
- header.IPv4Any,
- context.StackAddr,
+ name: "SourceUnspecified",
+ srcAddr: header.IPv4Any,
+ dstAddr: context.StackAddr,
},
{
- "SourceBroadcast",
- header.IPv4Broadcast,
- context.StackAddr,
+ name: "SourceBroadcast",
+ srcAddr: header.IPv4Broadcast,
+ dstAddr: context.StackAddr,
},
{
- "SourceOurMulticast",
- multicastAddr,
- context.StackAddr,
+ name: "SourceOurMulticast",
+ srcAddr: multicastAddr,
+ dstAddr: context.StackAddr,
},
{
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackAddr,
+ name: "SourceOtherMulticast",
+ srcAddr: otherMulticastAddr,
+ dstAddr: context.StackAddr,
},
{
- "DestUnspecified",
- context.TestAddr,
- header.IPv4Any,
+ name: "DestUnspecified",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Any,
},
{
- "DestBroadcast",
- context.TestAddr,
- header.IPv4Broadcast,
+ name: "DestBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Broadcast,
},
{
- "DestOurMulticast",
- context.TestAddr,
- multicastAddr,
+ name: "DestOurMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: multicastAddr,
},
{
- "DestOtherMulticast",
- context.TestAddr,
- otherMulticastAddr,
+ name: "DestOtherMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: otherMulticastAddr,
+ },
+ {
+ name: "SrcSubnetBroadcast",
+ srcAddr: subnetBroadcastAddr,
+ dstAddr: context.StackAddr,
+ },
+ {
+ name: "DestSubnetBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: subnetBroadcastAddr,
},
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5045,7 +5430,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
})
}
}
@@ -5053,8 +5438,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
// non unicast IPv6 address are not accepted.
func TestListenNoAcceptNonUnicastV6(t *testing.T) {
- multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
- otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+ multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
+ otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
tests := []struct {
name string
@@ -5104,11 +5489,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5149,7 +5530,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
})
}
}
@@ -5162,19 +5543,19 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
// Start listening.
listenBacklog := 1
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send two SYN's the first one should get a SYN-ACK, the
@@ -5197,7 +5578,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5233,14 +5614,14 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5254,7 +5635,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
pkt := c.GetPacket()
tcp = header.TCP(header.IPv4(pkt).Payload())
if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
}
@@ -5262,8 +5643,9 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err)
+ opt := tcpip.TCPSynRcvdCountThresholdOption(1)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
// Create TCP endpoint.
@@ -5309,14 +5691,14 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5325,7 +5707,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
// Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != tcpip.ErrWouldBlock {
select {
case <-ch:
@@ -5374,7 +5756,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5395,8 +5777,8 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.AckNum(uint32(irs) + 1),
- checker.SeqNum(uint32(iss + 1)),
+ checker.TCPAckNum(uint32(irs) + 1),
+ checker.TCPSeqNum(uint32(iss + 1)),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5414,7 +5796,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
RcvWnd: 30000,
})
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err != nil && err != tcpip.ErrWouldBlock {
t.Fatalf("Accept failed: %s", err)
@@ -5429,7 +5811,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5450,7 +5832,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
pkt := c.GetPacket()
tcpHdr = header.TCP(header.IPv4(pkt).Payload())
if string(tcpHdr.Payload()) != data {
- t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
}
}
@@ -5460,20 +5842,20 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
stats := c.Stack().Stats()
@@ -5487,14 +5869,14 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
defer c.WQ.EventUnregister(&we)
// Verify that there is only one acceptable connection at this point.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5503,7 +5885,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
}
if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
}
@@ -5514,14 +5896,14 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
srcPort := uint16(context.TestPort)
@@ -5546,10 +5928,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
time.Sleep(50 * time.Millisecond)
if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want)
+ t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -5557,14 +5939,14 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
defer c.WQ.EventUnregister(&we)
// Now check that there is one acceptable connections.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5579,28 +5961,28 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
- t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected)
+ t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
}
if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
- t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1)
+ t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -5610,14 +5992,14 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- aep, _, err := ep.Accept()
+ aep, _, err := ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- aep, _, err = ep.Accept()
+ aep, _, err = ep.Accept(nil)
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5625,25 +6007,25 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
}
}
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
- t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected)
+ t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected)
}
// Listening endpoint remains in listen state.
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
ep.Close()
// Give worker goroutines time to receive the close notification.
time.Sleep(1 * time.Second)
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Accepted endpoint remains open when the listen endpoint is closed.
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
}
@@ -5663,13 +6045,19 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 500.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
// Enable auto-tuning.
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ {
+ opt := tcpip.TCPModerateReceiveBufferOption(true)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
}
// Change the expected window scale to match the value needed for the
// maximum buffer size defined above.
@@ -5688,16 +6076,14 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
time.Sleep(latency)
rawEP.SendPacketWithTS([]byte{1}, tsVal)
- // Verify that the ACK has the expected window.
- wantRcvWnd := receiveBufferSize
- wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale))
- rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1))
+ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
+ rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
time.Sleep(25 * time.Millisecond)
// Allocate a large enough payload for the test.
- b := make([]byte, int(receiveBufferSize)*2)
- offset := 0
- payloadSize := receiveBufferSize - 1
+ payloadSize := receiveBufferSize * 2
+ b := make([]byte, int(payloadSize))
+
worker := (c.EP).(interface {
StopWork()
ResumeWork()
@@ -5706,11 +6092,15 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Stop the worker goroutine.
worker.StopWork()
- start := offset
- end := offset + payloadSize
+ start := 0
+ end := payloadSize / 2
packetsSent := 0
for ; start < end; start += mss {
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+ packetEnd := start + mss
+ if start+mss > end {
+ packetEnd = end
+ }
+ rawEP.SendPacketWithTS(b[start:packetEnd], tsVal)
packetsSent++
}
@@ -5718,29 +6108,20 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// are waiting to be read.
worker.ResumeWork()
- // Since we read no bytes the window should goto zero till the
- // application reads some of the data.
- // Discard all intermediate acks except the last one.
- if packetsSent > 100 {
- for i := 0; i < (packetsSent / 100); i++ {
- _ = c.GetPacket()
- }
+ // Since we sent almost the full receive buffer worth of data (some may have
+ // been dropped due to segment overheads), we should get a zero window back.
+ pkt = c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(pkt).Payload())
+ gotRcvWnd := tcpHdr.WindowSize()
+ wantAckNum := tcpHdr.AckNumber()
+ if got, want := int(gotRcvWnd), 0; got != want {
+ t.Fatalf("got rcvWnd: %d, want: %d", got, want)
}
- rawEP.VerifyACKRcvWnd(0)
time.Sleep(25 * time.Millisecond)
- // Verify that sending more data when window is closed is dropped and
- // not acked.
+ // Verify that sending more data when receiveBuffer is exhausted.
rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
- // Verify that the stack sends us back an ACK with the sequence number
- // of the last packet sent indicating it was dropped.
- p := c.GetPacket()
- checker.IPv4(t, p, checker.TCP(
- checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
- checker.Window(0),
- ))
-
// Now read all the data from the endpoint and verify that advertised
// window increases to the full available buffer size.
for {
@@ -5753,23 +6134,26 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Verify that we receive a non-zero window update ACK. When running
// under thread santizer this test can end up sending more than 1
// ack, 1 for the non-zero window
- p = c.GetPacket()
+ p := c.GetPacket()
checker.IPv4(t, p, checker.TCP(
- checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+ checker.TCPAckNum(uint32(wantAckNum)),
func(t *testing.T, h header.Transport) {
tcp, ok := h.(header.TCP)
if !ok {
return
}
- if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
- t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w)
+ // We use 10% here as the error margin upwards as the initial window we
+ // got was afer 1 segment was already in the receive buffer queue.
+ tolerance := 1.1
+ if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) {
+ t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance))
}
},
))
}
-// This test verifies that the auto tuning does not grow the receive buffer if
-// the application is not reading the data actively.
+// This test verifies that the advertised window is auto-tuned up as the
+// application is reading the data that is being received.
func TestReceiveBufferAutoTuning(t *testing.T) {
const mtu = 1500
const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
@@ -5779,26 +6163,33 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// Enable Auto-tuning.
stk := c.Stack()
- // Set lower limits for auto-tuning tests. This is required because the
- // test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
// Enable auto-tuning.
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ {
+ opt := tcpip.TCPModerateReceiveBufferOption(true)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
}
// Change the expected window scale to match the value needed for the
// maximum buffer size used by stack.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
-
- wantRcvWnd := receiveBufferSize
+ tsVal := uint32(rawEP.TSVal)
+ rawEP.NextSeqNum--
+ rawEP.SendPacketWithTS(nil, tsVal)
+ rawEP.NextSeqNum++
+ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
+ curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
scaleRcvWnd := func(rcvWnd int) uint16 {
return uint16(rcvWnd >> uint16(c.WindowScale))
}
@@ -5815,14 +6206,8 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
StopWork()
ResumeWork()
})
- tsVal := rawEP.TSVal
- // We are going to do our own computation of what the moderated receive
- // buffer should be based on sent/copied data per RTT and verify that
- // the advertised window by the stack matches our calculations.
- prevCopied := 0
- done := false
latency := 1 * time.Millisecond
- for i := 0; !done; i++ {
+ for i := 0; i < 5; i++ {
tsVal++
// Stop the worker goroutine.
@@ -5844,15 +6229,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// Give 1ms for the worker to process the packets.
time.Sleep(1 * time.Millisecond)
- // Verify that the advertised window on the ACK is reduced by
- // the total bytes sent.
- expectedWnd := wantRcvWnd - totalSent
- if packetsSent > 100 {
- for i := 0; i < (packetsSent / 100); i++ {
- _ = c.GetPacket()
+ lastACK := c.GetPacket()
+ // Discard any intermediate ACKs and only check the last ACK we get in a
+ // short time period of few ms.
+ for {
+ time.Sleep(1 * time.Millisecond)
+ pkt := c.GetPacketNonBlocking()
+ if pkt == nil {
+ break
}
+ lastACK = pkt
+ }
+ if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want {
+ t.Fatalf("advertised window got: %d, want <= %d", got, want)
}
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd))
// Now read all the data from the endpoint and invoke the
// moderation API to allow for receive buffer auto-tuning
@@ -5882,30 +6272,28 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// In the first iteration the receiver based RTT is not
// yet known as a result the moderation code should not
// increase the advertised window.
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd))
- prevCopied = totalCopied
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd))
} else {
- rttCopied := totalCopied
- if i == 1 {
- // The moderation code accumulates copied bytes till
- // RTT is established. So add in the bytes sent in
- // the first iteration to the total bytes for this
- // RTT.
- rttCopied += prevCopied
- // Now reset it to the initial value used by the
- // auto tuning logic.
- prevCopied = tcp.InitialCwnd * mss * 2
+ // Read loop above could generate an ACK if the window had dropped to
+ // zero and then read had opened it up.
+ lastACK := c.GetPacket()
+ // Discard any intermediate ACKs and only check the last ACK we get in a
+ // short time period of few ms.
+ for {
+ time.Sleep(1 * time.Millisecond)
+ pkt := c.GetPacketNonBlocking()
+ if pkt == nil {
+ break
+ }
+ lastACK = pkt
}
- newWnd := rttCopied<<1 + 16*mss
- grow := (newWnd * (rttCopied - prevCopied)) / prevCopied
- newWnd += (grow << 1)
- if newWnd > maxReceiveBufferSize {
- newWnd = maxReceiveBufferSize
- done = true
+ curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale
+ // If thew new current window is close maxReceiveBufferSize then terminate
+ // the loop. This can happen before all iterations are done due to timing
+ // differences when running the test.
+ if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 {
+ break
}
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd))
- wantRcvWnd = newWnd
- prevCopied = rttCopied
// Increase the latency after first two iterations to
// establish a low RTT value in the receiver since it
// only tracks the lowest value. This ensures that when
@@ -5918,6 +6306,12 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
offset += payloadSize
payloadSize *= 2
}
+ // Check that at the end of our iterations the receive window grew close to the maximum
+ // permissible size of maxReceiveBufferSize/2
+ if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want {
+ t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want)
+ }
+
}
func TestDelayEnabled(t *testing.T) {
@@ -5926,7 +6320,7 @@ func TestDelayEnabled(t *testing.T) {
checkDelayOption(t, c, false, false) // Delay is disabled by default.
for _, v := range []struct {
- delayEnabled tcp.DelayEnabled
+ delayEnabled tcpip.TCPDelayEnabled
wantDelayOption bool
}{
{delayEnabled: false, wantDelayOption: false},
@@ -5934,19 +6328,19 @@ func TestDelayEnabled(t *testing.T) {
} {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
- t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err)
}
checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
}
}
-func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) {
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) {
t.Helper()
- var gotDelayEnabled tcp.DelayEnabled
+ var gotDelayEnabled tcpip.TCPDelayEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
- t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err)
+ t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
}
if gotDelayEnabled != wantDelayEnabled {
t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
@@ -5954,7 +6348,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
if err != nil {
- t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err)
+ t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
}
gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
@@ -5976,24 +6370,27 @@ func TestTCPLingerTimeout(t *testing.T) {
tcpLingerTimeout time.Duration
want time.Duration
}{
- {"NegativeLingerTimeout", -123123, 0},
- {"ZeroLingerTimeout", 0, 0},
+ {"NegativeLingerTimeout", -123123, -1},
+ // Zero is treated same as the stack's default TCP_LINGER2 timeout.
+ {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout},
{"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
// Values > stack's TCPLingerTimeout are capped to the stack's
// value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
- {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second},
+ {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil {
- t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err)
+ v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)
+ if err := c.EP.SetSockOpt(&v); err != nil {
+ t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err)
}
- var v tcpip.TCPLingerTimeoutOption
+
+ v = 0
if err := c.EP.GetSockOpt(&v); err != nil {
- t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err)
+ t.Fatalf("GetSockOpt(&%T) = %s", v, err)
}
if got, want := time.Duration(v), tc.want; got != want {
- t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want)
+ t.Fatalf("got linger timeout = %s, want = %s", got, want)
}
})
}
@@ -6047,12 +6444,12 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6066,8 +6463,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6084,8 +6481,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Now send a RST and this should be ignored and not
@@ -6113,8 +6510,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
}
@@ -6166,12 +6563,12 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6185,8 +6582,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6203,8 +6600,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Out of order ACK should generate an immediate ACK in
@@ -6220,8 +6617,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
}
@@ -6273,12 +6670,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6292,8 +6689,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6310,8 +6707,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Send a SYN request w/ sequence number lower than
@@ -6328,6 +6725,13 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
+ // drain any older notifications from the notification channel before attempting
+ // 2nd connection.
+ select {
+ case <-ch:
+ default:
+ }
+
// Send a SYN request w/ sequence number higher than
// the highest sequence number sent.
iss = seqnum.Value(792)
@@ -6356,12 +6760,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
c.SendPacket(nil, ackHeaders)
// Try to accept the connection.
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6379,8 +6783,9 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
// after 5 seconds in TIME_WAIT state.
tcpTimeWaitTimeout := 5 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
}
want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
@@ -6429,12 +6834,12 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6448,8 +6853,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6466,8 +6871,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
time.Sleep(2 * time.Second)
@@ -6481,8 +6886,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Sleep for 4 seconds so at this point we are 1 second past the
@@ -6510,15 +6915,15 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
}
@@ -6529,8 +6934,9 @@ func TestTCPCloseWithData(t *testing.T) {
// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
// after 5 seconds in TIME_WAIT state.
tcpTimeWaitTimeout := 5 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
}
wq := &waiter.Queue{}
@@ -6578,12 +6984,12 @@ func TestTCPCloseWithData(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6609,8 +7015,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Now write a few bytes and then close the endpoint.
@@ -6628,8 +7034,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -6643,8 +7049,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
- checker.AckNum(uint32(iss+2)),
+ checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))),
+ checker.TCPAckNum(uint32(iss+2)),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
// First send a partial ACK.
@@ -6689,8 +7095,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
}
@@ -6710,12 +7116,15 @@ func TestTCPUserTimeout(t *testing.T) {
// expired.
initRTO := 1 * time.Second
userTimeout := initRTO / 2
- c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+ v := tcpip.TCPUserTimeoutOption(userTimeout)
+ if err := c.EP.SetSockOpt(&v); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err)
+ }
// Send some data and wait before ACKing it.
view := buffer.NewView(3)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -6723,8 +7132,8 @@ func TestTCPUserTimeout(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -6758,18 +7167,21 @@ func TestTCPUserTimeout(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(next)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -6781,22 +7193,35 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+ const keepAliveIdle = 100 * time.Millisecond
const keepAliveInterval = 3 * time.Second
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
- c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10)
- c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+ keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle)
+ if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err)
+ }
+ keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval)
+ if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err)
+ }
+ if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil {
+ t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil {
+ t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err)
+ }
// Set userTimeout to be the duration to be 1 keepalive
// probes. Which means that after the first probe is sent
// the second one should cause the connection to be
// closed due to userTimeout being hit.
- userTimeout := 1 * keepAliveInterval
- c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+ userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval)
+ if err := c.EP.SetSockOpt(&userTimeout); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err)
+ }
// Check that the connection is still alive.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Now receive 1 keepalives, but don't ACK it.
@@ -6804,8 +7229,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -6830,23 +7255,26 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
-func TestIncreaseWindowOnReceive(t *testing.T) {
+func TestIncreaseWindowOnRead(t *testing.T) {
// This test ensures that the endpoint sends an ack,
- // after recv() when the window grows to more than 1 MSS.
+ // after read() when the window grows by more than 1 MSS.
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -6855,10 +7283,9 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
- remain := rcvBuf
+ remain := rcvBuf * 2
sent := 0
data := make([]byte, defaultMTU/2)
- lastWnd := uint16(0)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
@@ -6871,46 +7298,43 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
})
sent += len(data)
remain -= len(data)
-
- lastWnd = uint16(remain)
- if remain > 0xffff {
- lastWnd = 0xffff
- }
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(lastWnd),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Break once the window drops below defaultMTU/2
+ if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 {
+ break
+ }
}
- if lastWnd == 0xffff || lastWnd == 0 {
- t.Fatalf("expected small, non-zero window: %d", lastWnd)
- }
-
- // We now have < 1 MSS in the buffer space. Read the data! An
- // ack should be sent in response to that. The window was not
- // zero, but it grew to larger than MSS.
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %v", err)
- }
-
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %v", err)
+ // We now have < 1 MSS in the buffer space. Read at least > 2 MSS
+ // worth of data as receive buffer space
+ read := 0
+ // defaultMTU is a good enough estimate for the MSS used for this
+ // connection.
+ for read < defaultMTU*2 {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ read += len(v)
}
- // After reading two packets, we surely crossed MSS. See the ack:
+ // After reading > MSS worth of data, we surely crossed MSS. See the ack:
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(0xffff)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -6930,7 +7354,6 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
remain := rcvBuf
sent := 0
data := make([]byte, defaultMTU/2)
- lastWnd := uint16(0)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
@@ -6943,39 +7366,29 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
})
sent += len(data)
remain -= len(data)
-
- lastWnd = uint16(remain)
- if remain > 0xffff {
- lastWnd = 0xffff
- }
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(lastWnd),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindowLessThanEq(0xffff),
checker.TCPFlags(header.TCPFlagAck),
),
)
}
- if lastWnd == 0xffff || lastWnd == 0 {
- t.Fatalf("expected small, non-zero window: %d", lastWnd)
- }
-
// Increasing the buffer from should generate an ACK,
// since window grew from small value to larger equal MSS
c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
- // After reading two packets, we surely crossed MSS. See the ack:
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(0xffff)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -6996,14 +7409,15 @@ func TestTCPDeferAccept(t *testing.T) {
}
const tcpDeferAccept = 1 * time.Second
- if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept)
+ if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err)
}
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Send data. This should result in an acceptable endpoint.
@@ -7019,14 +7433,14 @@ func TestTCPDeferAccept(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
// Give a bit of time for the socket to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept()
+ aep, _, err := c.EP.Accept(nil)
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
}
aep.Close()
@@ -7034,8 +7448,8 @@ func TestTCPDeferAccept(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestTCPDeferAcceptTimeout(t *testing.T) {
@@ -7053,14 +7467,15 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
}
const tcpDeferAccept = 1 * time.Second
- if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept)
+ if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err)
}
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Sleep for a little of the tcpDeferAccept timeout.
@@ -7071,7 +7486,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
// Send data. This should result in an acceptable endpoint.
c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
@@ -7087,14 +7502,14 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
// Give sometime for the endpoint to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept()
+ aep, _, err := c.EP.Accept(nil)
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
}
aep.Close()
@@ -7103,8 +7518,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestResetDuringClose(t *testing.T) {
@@ -7129,8 +7544,8 @@ func TestResetDuringClose(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(irs.Add(1))),
- checker.AckNum(uint32(iss.Add(5)))))
+ checker.TCPSeqNum(uint32(irs.Add(1))),
+ checker.TCPAckNum(uint32(iss.Add(5)))))
// Close in a separate goroutine so that we can trigger
// a race with the RST we send below. This should not
@@ -7160,3 +7575,65 @@ func TestResetDuringClose(t *testing.T) {
wg.Wait()
}
+
+func TestStackTimeWaitReuse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err)
+ }
+ if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want {
+ t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
+ }
+}
+
+func TestSetStackTimeWaitReuse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ testCases := []struct {
+ v int
+ err *tcpip.Error
+ }{
+ {int(tcpip.TCPTimeWaitReuseDisabled), nil},
+ {int(tcpip.TCPTimeWaitReuseGlobal), nil},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue},
+ {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue},
+ }
+
+ for _, tc := range testCases {
+ opt := tcpip.TCPTimeWaitReuseOption(tc.v)
+ err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)
+ if got, want := err, tc.err; got != want {
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err)
+ }
+ if tc.err != nil {
+ continue
+ }
+
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err)
+ }
+
+ if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want {
+ t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
+ }
+ }
+}
+
+// generateRandomPayload generates a random byte slice of the specified length
+// causing a fatal test failure if it is unable to do so.
+func generateRandomPayload(t *testing.T, n int) []byte {
+ t.Helper()
+ buf := make([]byte, n)
+ if _, err := rand.Read(buf); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+ return buf
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 8edbff964..0f9ed06cd 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -131,8 +131,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
defer c.Cleanup()
if cookieEnabled {
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -158,9 +159,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
checker.PayloadLen(len(data)+header.TCPMinimumSize+12),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(wndSize),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(wndSize),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
checker.TCPTimestampChecker(true, 0, tsVal+1),
),
@@ -180,7 +181,8 @@ func TestTimeStampEnabledAccept(t *testing.T) {
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
+ // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that.
+ {false, 5, 0x4000},
}
for _, tc := range testCases {
timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
@@ -192,8 +194,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
defer c.Cleanup()
if cookieEnabled {
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -217,9 +220,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(wndSize),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(wndSize),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
checker.TCPTimestampChecker(false, 0, 0),
),
@@ -235,7 +238,9 @@ func TestTimeStampDisabledAccept(t *testing.T) {
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
+ // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of
+ // that.
+ {false, 5, 0x4000},
}
for _, tc := range testCases {
timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 7b1d72cf4..79646fefe 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -53,11 +53,11 @@ const (
TestPort = 4096
// StackV6Addr is the IPv6 address assigned to the stack.
- StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
// TestV6Addr is the source address for packets sent to the stack via
// the link layer endpoint.
- TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// StackV4MappedAddr is StackAddr as a mapped v6 address.
StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
@@ -68,11 +68,23 @@ const (
// V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
- // testInitialSequenceNumber is the initial sequence number sent in packets that
+ // TestInitialSequenceNumber is the initial sequence number sent in packets that
// are sent in response to a SYN or in the initial SYN sent to the stack.
- testInitialSequenceNumber = 789
+ TestInitialSequenceNumber = 789
)
+// StackAddrWithPrefix is StackAddr with its associated prefix length.
+var StackAddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackAddr,
+ PrefixLen: 24,
+}
+
+// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length.
+var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackV6Addr,
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+}
+
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
@@ -133,30 +145,39 @@ type Context struct {
// WindowScale is the expected window scale in SYN packets sent by
// the stack.
WindowScale uint8
+
+ // RcvdWindowScale is the actual window scale sent by the stack in
+ // SYN/SYN-ACK.
+ RcvdWindowScale uint8
}
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
+ const sendBufferSize = 1 << 20 // 1 MiB
+ const recvBufferSize = 1 << 20 // 1 MiB
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err)
}
// Increase minimum RTO in tests to avoid test flakes due to early
// retransmit in case the test executors are overloaded and cause timers
// to fire earlier than expected.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil {
- t.Fatalf("failed to set stack-wide minRTO: %s", err)
+ minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
}
// Some of the congestion control tests send up to 640 packets, we so
@@ -179,12 +200,20 @@ func New(t *testing.T, mtu uint32) *Context {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: StackAddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: StackV6AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -202,7 +231,7 @@ func New(t *testing.T, mtu uint32) *Context {
t: t,
s: s,
linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
+ WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
}
}
@@ -236,18 +265,17 @@ func (c *Context) CheckNoPacket(errMsg string) {
c.CheckNoPacketTimeout(errMsg, 1*time.Second)
}
-// GetPacket reads a packet from the link layer endpoint and verifies
+// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies
// that it is an IPv4 packet with the expected source and destination
-// addresses. It will fail with an error if no packet is received for
-// 2 seconds.
-func (c *Context) GetPacket() []byte {
+// addresses. If no packet is received in the specified timeout it will return
+// nil.
+func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte {
c.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
- c.t.Fatalf("Packet wasn't written out")
return nil
}
@@ -255,8 +283,16 @@ func (c *Context) GetPacket() []byte {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ // Just check that the stack set the transport protocol number for outbound
+ // TCP messages.
+ // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
+ // of the headerinfo.
+ if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
+ c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
@@ -266,6 +302,21 @@ func (c *Context) GetPacket() []byte {
return b
}
+// GetPacket reads a packet from the link layer endpoint and verifies
+// that it is an IPv4 packet with the expected source and destination
+// addresses.
+func (c *Context) GetPacket() []byte {
+ c.t.Helper()
+
+ p := c.GetPacketWithTimeout(5 * time.Second)
+ if p == nil {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ return p
+}
+
// GetPacketNonBlocking reads a packet from the link layer endpoint
// and verifies that it is an IPv4 packet with the expected source
// and destination address. If no packet is available it will return
@@ -282,15 +333,23 @@ func (c *Context) GetPacketNonBlocking() []byte {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ // Just check that the stack set the transport protocol number for outbound
+ // TCP messages.
+ // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
+ // of the headerinfo.
+ if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
+ c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
return b
}
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
-func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
+func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) {
// Allocate a buffer data and headers.
buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
if len(buf) > maxTotalSize {
@@ -314,11 +373,15 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
const icmpv4VariableHeaderOffset = 4
copy(icmp[icmpv4VariableHeaderOffset:], p1)
copy(icmp[header.ICMPv4PayloadOffset:], p2)
+ icmp.SetChecksum(0)
+ checksum := ^header.Checksum(icmp, 0 /* initial */)
+ icmp.SetChecksum(checksum)
// Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// BuildSegment builds a TCP segment based on the given Headers and payload.
@@ -372,26 +435,29 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: s,
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: c.BuildSegment(payload, h),
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
// provided source and destination IPv4 addresses.
func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendAck sends an ACK packet.
@@ -441,8 +507,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
checker.PayloadLen(size+header.TCPMinimumSize+optlen),
checker.TCP(
checker.DstPort(TestPort),
- checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -468,8 +534,8 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
checker.PayloadLen(size+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(TestPort),
- checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -512,9 +578,8 @@ func (c *Context) GetV6Packet() []byte {
if p.Proto != ipv6.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
- b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
- copy(b, p.Pkt.Header.View())
- copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
return b
@@ -564,9 +629,10 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
})
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt)
}
// CreateConnected creates a connected TCP endpoint.
@@ -607,6 +673,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
}
tcpHdr := header.TCP(header.IPv4(b).Payload())
+ synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */)
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
c.SendPacket(nil, &Headers{
@@ -624,15 +691,15 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
checker.TCP(
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
),
)
// Wait for connection to be established.
select {
case <-notifyCh:
- if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+ if err := c.EP.LastError(); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -642,6 +709,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ c.RcvdWindowScale = uint8(synOpts.WS)
c.Port = tcpHdr.SourcePort()
}
@@ -713,17 +781,18 @@ func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
}
-// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
-// tsVal.
-func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches
+// the provided tsVal as well as returns the original packet.
+func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte {
+ r.C.t.Helper()
// Read ACK and verify that tsEcr of ACK packet is [1,2,3,4]
ackPacket := r.C.GetPacket()
checker.IPv4(r.C.t, ackPacket,
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
checker.TCPTimestampChecker(true, 0, tsVal),
),
)
@@ -731,19 +800,28 @@ func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
opts := tcpSeg.ParsedOptions()
r.RecentTS = opts.TSVal
+ return ackPacket
+}
+
+// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
+// tsVal.
+func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+ r.C.t.Helper()
+ _ = r.VerifyAndReturnACKWithTS(tsVal)
}
// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
// matches the provided rcvWnd.
func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
+ r.C.t.Helper()
ackPacket := r.C.GetPacket()
checker.IPv4(r.C.t, ackPacket,
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
- checker.Window(rcvWnd),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
+ checker.TCPWindow(rcvWnd),
),
)
}
@@ -762,8 +840,8 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
checker.TCPSACKBlockChecker(sackBlocks),
),
)
@@ -837,7 +915,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
- iss := seqnum.Value(testInitialSequenceNumber)
+ iss := seqnum.Value(TestInitialSequenceNumber)
c.SendPacket(nil, &Headers{
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
@@ -855,8 +933,8 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
tcpCheckers := []checker.TransportChecker{
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS) + 1),
- checker.AckNum(uint32(iss) + 1),
+ checker.TCPSeqNum(uint32(c.IRS) + 1),
+ checker.TCPAckNum(uint32(iss) + 1),
}
// Verify that tsEcr of ACK packet is wantOptions.TSVal if the
@@ -876,8 +954,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Wait for connection to be established.
select {
case <-notifyCh:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.LastError(); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -892,7 +969,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Mark in context that timestamp option is enabled for this endpoint.
c.TimeStampEnabled = true
-
+ c.RcvdWindowScale = uint8(synOptions.WS)
return &RawEndpoint{
C: c,
SrcPort: tcpSeg.DestinationPort(),
@@ -943,12 +1020,12 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
c.t.Fatalf("Accept failed: %v", err)
}
@@ -985,6 +1062,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ c.t.Helper()
opts := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
offset += header.EncodeMSSOption(uint32(maxPayload), opts)
@@ -1009,7 +1087,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
offset += paddingToAdd
// Send a SYN request.
- iss := seqnum.Value(testInitialSequenceNumber)
+ iss := seqnum.Value(TestInitialSequenceNumber)
c.SendPacket(nil, &Headers{
SrcPort: TestPort,
DstPort: StackPort,
@@ -1023,13 +1101,14 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// are present.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
+ rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */)
c.IRS = seqnum.Value(tcp.SequenceNumber())
tcpCheckers := []checker.TransportChecker{
checker.SrcPort(StackPort),
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(iss) + 1),
+ checker.TCPAckNum(uint32(iss) + 1),
checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
}
@@ -1072,6 +1151,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// Send ACK.
c.SendPacket(nil, ackHeaders)
+ c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
c.Port = StackPort
return &RawEndpoint{
@@ -1091,7 +1171,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
// for the Stack in the context.
func (c *Context) SACKEnabled() bool {
- var v tcp.SACKEnabled
+ var v tcpip.TCPSACKEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
return false
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index c70525f27..7981d469b 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -85,6 +85,7 @@ func (t *timer) init(w *sleep.Waker) {
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
t.timer.Stop()
+ *t = timer{}
}
// checkExpiration checks if the given timer has actually expired, it should be
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
new file mode 100644
index 000000000..dbd6dff54
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer_test.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+)
+
+func TestCleanup(t *testing.T) {
+ const (
+ timerDurationSeconds = 2
+ isAssertedTimeoutSeconds = timerDurationSeconds + 1
+ )
+
+ tmr := timer{}
+ w := sleep.Waker{}
+ tmr.init(&w)
+ tmr.enable(timerDurationSeconds * time.Second)
+ tmr.cleanup()
+
+ if want := (timer{}); tmr != want {
+ t.Errorf("got tmr = %+v, want = %+v", tmr, want)
+ }
+
+ // The waker should not be asserted.
+ for i := 0; i < isAssertedTimeoutSeconds; i++ {
+ time.Sleep(time.Second)
+ if w.IsAsserted() {
+ t.Fatalf("waker asserted unexpectedly")
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 12bc1b5b5..558b06df0 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
return st
}
+// State returns the current state of the TCB.
+func (t *TCB) State() Result {
+ return t.state
+}
+
// IsAlive returns true as long as the connection is established(Alive)
// or connecting state.
func (t *TCB) IsAlive() bool {
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index b5d2d0ba6..c78549424 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 647b2067a..cdb5127ab 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,6 +15,9 @@
package udp
import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -93,6 +96,7 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
+ sndBufSizeMax int
state EndpointState
route stack.Route `state:"manual"`
dstPort uint16
@@ -102,9 +106,10 @@ type endpoint struct {
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
multicastLoop bool
- reusePort bool
+ portFlags ports.Flags
bindToDevice tcpip.NICID
broadcast bool
+ noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -134,7 +139,7 @@ type endpoint struct {
// multicastMemberships that need to be remvoed when the endpoint is
// closed. Protected by the mu mutex.
- multicastMemberships []multicastMembership
+ multicastMemberships map[multicastMembership]struct{}
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -149,6 +154,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
}
// +stateify savable
@@ -158,7 +166,7 @@ type multicastMembership struct {
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
- return &endpoint{
+ e := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
@@ -177,13 +185,27 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
// TTL=1.
//
// Linux defaults to TTL=1.
- multicastTTL: 1,
- multicastLoop: true,
- rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
- state: StateInitial,
- uniqueID: s.UniqueID(),
+ multicastTTL: 1,
+ multicastLoop: true,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
+ multicastMemberships: make(map[multicastMembership]struct{}),
+ state: StateInitial,
+ uniqueID: s.UniqueID(),
+ }
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
}
+
+ return e
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -191,7 +213,7 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
-func (e *endpoint) takeLastError() *tcpip.Error {
+func (e *endpoint) LastError() *tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
@@ -213,16 +235,16 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
}
- for _, mem := range e.multicastMemberships {
+ for mem := range e.multicastMemberships {
e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
}
- e.multicastMemberships = nil
+ e.multicastMemberships = make(map[multicastMembership]struct{})
// Close the receive list and drain it.
e.rcvMu.Lock()
@@ -247,15 +269,10 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if err := e.takeLastError(); err != nil {
+ if err := e.LastError(); err != nil {
return buffer.View{}, tcpip.ControlMessages{}, err
}
@@ -398,7 +415,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- if err := e.takeLastError(); err != nil {
+ if err := e.LastError(); err != nil {
return 0, nil, err
}
@@ -430,24 +447,33 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
var route *stack.Route
+ var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
var dstPort uint16
if to == nil {
route = &e.route
dstPort = e.dstPort
-
- if route.IsResolutionRequired() {
- // Promote lock to exclusive if using a shared route, given that it may need to
- // change in Route.Resolve() call below.
+ resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
+ // Promote lock to exclusive if using a shared route, given that it may
+ // need to change in Route.Resolve() call below.
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
if e.state != StateConnected {
- return 0, nil, tcpip.ErrInvalidEndpointState
+ err = tcpip.ErrInvalidEndpointState
+ }
+ if err == nil && route.IsResolutionRequired() {
+ ch, err = route.Resolve(waker)
}
+
+ e.mu.Unlock()
+ e.mu.RLock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != StateConnected {
+ err = tcpip.ErrInvalidEndpointState
+ }
+ return
}
} else {
// Reject destination address if it goes through a different
@@ -461,10 +487,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
- if to.Addr == header.IPv4Broadcast && !e.broadcast {
- return 0, nil, tcpip.ErrBroadcastDisabled
- }
-
dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
@@ -478,10 +500,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
route = &r
dstPort = dst.Port
+ resolve = route.Resolve
+ }
+
+ if !e.broadcast && route.IsOutboundBroadcast() {
+ return 0, nil, tcpip.ErrBroadcastDisabled
}
if route.IsResolutionRequired() {
- if ch, err := route.Resolve(nil); err != nil {
+ if ch, err := resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
return 0, ch, tcpip.ErrNoLinkAddress
}
@@ -507,7 +534,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -531,6 +558,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.multicastLoop = v
e.mu.Unlock()
+ case tcpip.NoChecksumOption:
+ e.mu.Lock()
+ e.noChecksum = v
+ e.mu.Unlock()
+
case tcpip.ReceiveTOSOption:
e.mu.Lock()
e.receiveTOS = v
@@ -552,10 +584,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.mu.Unlock()
case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.portFlags.MostRecent = v
+ e.mu.Unlock()
case tcpip.ReusePortOption:
e.mu.Lock()
- e.reusePort = v
+ e.portFlags.LoadBalanced = v
e.mu.Unlock()
case tcpip.V6OnlyOption:
@@ -581,6 +616,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
e.multicastTTL = uint8(v)
@@ -602,17 +644,52 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.mu.Unlock()
case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err))
+ }
+
+ if v < rs.Min {
+ v = rs.Min
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ e.mu.Lock()
+ e.rcvBufSizeMax = v
+ e.mu.Unlock()
+ return nil
case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err))
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
}
return nil
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
- case tcpip.MulticastInterfaceOption:
+ case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
defer e.mu.Unlock()
@@ -648,7 +725,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.multicastNICID = nic
e.multicastAddr = addr
- case tcpip.AddMembershipOption:
+ case *tcpip.AddMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
return tcpip.ErrInvalidOptionValue
}
@@ -679,19 +756,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- for _, mem := range e.multicastMemberships {
- if mem == memToInsert {
- return tcpip.ErrPortInUse
- }
+ if _, ok := e.multicastMemberships[memToInsert]; ok {
+ return tcpip.ErrPortInUse
}
if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
return err
}
- e.multicastMemberships = append(e.multicastMemberships, memToInsert)
+ e.multicastMemberships[memToInsert] = struct{}{}
- case tcpip.RemoveMembershipOption:
+ case *tcpip.RemoveMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
return tcpip.ErrInvalidOptionValue
}
@@ -713,18 +788,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
- memToRemoveIndex := -1
e.mu.Lock()
defer e.mu.Unlock()
- for i, mem := range e.multicastMemberships {
- if mem == memToRemove {
- memToRemoveIndex = i
- break
- }
- }
- if memToRemoveIndex == -1 {
+ if _, ok := e.multicastMemberships[memToRemove]; !ok {
return tcpip.ErrBadLocalAddress
}
@@ -732,17 +800,24 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return err
}
- e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
- e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ delete(e.multicastMemberships, memToRemove)
- case tcpip.BindToDeviceOption:
- id := tcpip.NICID(v)
+ case *tcpip.BindToDeviceOption:
+ id := tcpip.NICID(*v)
if id != 0 && !e.stack.HasNIC(id) {
return tcpip.ErrUnknownDevice
}
e.mu.Lock()
e.bindToDevice = id
e.mu.Unlock()
+
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
}
return nil
}
@@ -765,6 +840,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.NoChecksumOption:
+ e.mu.RLock()
+ v := e.noChecksum
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.ReceiveTOSOption:
e.mu.RLock()
v := e.receiveTOS
@@ -789,11 +870,15 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
case tcpip.ReuseAddressOption:
- return false, nil
+ e.mu.RLock()
+ v := e.portFlags.MostRecent
+ e.mu.RUnlock()
+
+ return v, nil
case tcpip.ReusePortOption:
e.mu.RLock()
- v := e.reusePort
+ v := e.portFlags.LoadBalanced
e.mu.RUnlock()
return v, nil
@@ -810,6 +895,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
+ case tcpip.AcceptConnOption:
+ return false, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -830,6 +918,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
v := int(e.multicastTTL)
@@ -848,7 +940,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -870,10 +962,8 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
switch o := opt.(type) {
- case tcpip.ErrorOption:
- return e.takeLastError()
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
*o = tcpip.MulticastInterfaceOption{
@@ -887,6 +977,11 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = tcpip.BindToDeviceOption(e.bindToDevice)
e.mu.RUnlock()
+ case *tcpip.LingerOption:
+ e.mu.RLock()
+ *o = e.linger
+ e.mu.RUnlock()
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -895,22 +990,30 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error {
- // Allocate a buffer for the UDP header.
- hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
+ Data: data,
+ })
+ pkt.Owner = owner
- // Initialize the header.
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ // Initialize the UDP header.
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ pkt.TransportProtocolNumber = ProtocolNumber
- length := uint16(hdr.UsedLength() + data.Size())
+ length := uint16(pkt.Size())
udp.Encode(&header.UDPFields{
SrcPort: localPort,
DstPort: remotePort,
Length: length,
})
- // Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
+ (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
@@ -921,12 +1024,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
if useDefaultTTL {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{
- Header: hdr,
- Data: data,
- TransportHeader: buffer.View(udp),
- Owner: owner,
- }); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: ProtocolNumber,
+ TTL: ttl,
+ TOS: tos,
+ }, pkt); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return err
}
@@ -958,6 +1060,11 @@ func (e *endpoint) Disconnect() *tcpip.Error {
id stack.TransportEndpointID
btd tcpip.NICID
)
+
+ // We change this value below and we need the old value to unregister
+ // the endpoint.
+ boundPortFlags := e.boundPortFlags
+
// Exclude ephemerally bound endpoints.
if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
@@ -970,16 +1077,17 @@ func (e *endpoint) Disconnect() *tcpip.Error {
return err
}
e.state = StateBound
+ boundPortFlags = e.boundPortFlags
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
@@ -1051,6 +1159,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
}
+ oldPortFlags := e.boundPortFlags
+
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
@@ -1058,7 +1168,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
}
e.ID = id
@@ -1116,28 +1226,23 @@ func (*endpoint) Listen(int) *tcpip.Error {
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- flags := ports.Flags{
- LoadBalanced: e.reusePort,
- // FIXME(b/129164367): Support SO_REUSEADDR.
- MostRecent: false,
- }
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
if err != nil {
return id, e.bindToDevice, err
}
- e.boundPortFlags = flags
id.LocalPort = port
}
+ e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
return id, e.bindToDevice, err
@@ -1264,27 +1369,54 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
e.rcvMu.Unlock()
}
+ e.lastErrorMu.Lock()
+ hasError := e.lastError != nil
+ e.lastErrorMu.Unlock()
+ if hasError {
+ result |= waiter.EventErr
+ }
return result
}
+// verifyChecksum verifies the checksum unless RX checksum offload is enabled.
+// On IPv4, UDP checksum is optional, and a zero value means the transmitter
+// omitted the checksum generation (RFC768).
+// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool {
+ if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
+ (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ return hdr.CalculateChecksum(xsum) == 0xffff
+ }
+ return true
+}
+
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
- hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() {
+ hdr := header.UDP(pkt.TransportHeader().View())
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
- pkt.Data.TrimFront(header.UDPMinimumSize)
+ if !verifyChecksum(r, hdr, pkt) {
+ // Checksum Error.
+ e.stack.Stats().UDP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
+ return
+ }
- e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
+ e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed {
e.rcvMu.Unlock()
@@ -1317,15 +1449,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Save any useful information from the network header to the packet.
switch r.NetProto {
case header.IPv4ProtocolNumber:
- packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
- packet.packetInfo.LocalAddr = r.LocalAddress
- packet.packetInfo.DestinationAddr = r.RemoteAddress
- packet.packetInfo.NIC = r.NICID()
+ packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS()
case header.IPv6ProtocolNumber:
- packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
+ packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS()
}
- packet.timestamp = e.stack.NowNanoseconds()
+ // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
+ // address. packetInfo.LocalAddr should hold a unicast address that can be
+ // used to respond to the incoming packet.
+ packet.packetInfo.LocalAddr = r.LocalAddress
+ packet.packetInfo.DestinationAddr = r.LocalAddress
+ packet.packetInfo.NIC = r.NICID()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
@@ -1336,17 +1471,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
e.mu.RLock()
- defer e.mu.RUnlock()
-
if e.state == StateConnected {
e.lastErrorMu.Lock()
- defer e.lastErrorMu.Unlock()
-
e.lastError = tcpip.ErrConnectionRefused
+ e.lastErrorMu.Unlock()
+ e.mu.RUnlock()
+
+ e.waiterQueue.Notify(waiter.EventErr)
+ return
}
+ e.mu.RUnlock()
}
}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 851e6b635..858c99a45 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- for _, m := range e.multicastMemberships {
+ for m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a674ceb68..3ae6cc221 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
route: r,
@@ -61,7 +61,7 @@ type ForwarderRequest struct {
stack *stack.Stack
route *stack.Route
id stack.TransportEndpointID
- pkt stack.PacketBuffer
+ pkt *stack.PacketBuffer
}
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
@@ -73,7 +73,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
@@ -81,7 +81,9 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.ID = r.id
ep.route = r.route.Clone()
ep.dstPort = r.id.RemotePort
+ ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto}
ep.RegisterNICID = r.route.NICID()
+ ep.boundPortFlags = ep.portFlags
ep.state = StateConnected
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 52af6de22..da5b1deb2 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -12,18 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package udp contains the implementation of the UDP transport protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.NewProtocol() as one of the
-// transport protocols when calling stack.New(). Then endpoints can be created
-// by passing udp.ProtocolNumber as the transport protocol number when calling
-// Stack.NewEndpoint().
+// Package udp contains the implementation of the UDP transport protocol.
package udp
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/waiter"
@@ -32,9 +28,25 @@ import (
const (
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
+
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4KiB bytes.
+
+ // DefaultSendBufferSize is the default size of the send buffer for
+ // an endpoint.
+ DefaultSendBufferSize = 32 << 10 // 32KiB
+
+ // DefaultReceiveBufferSize is the default size of the receive buffer
+ // for an endpoint.
+ DefaultReceiveBufferSize = 32 << 10 // 32KiB
+
+ // MaxBufferSize is the largest size a receive/send buffer can grow to.
+ MaxBufferSize = 4 << 20 // 4MiB
)
-type protocol struct{}
+type protocol struct {
+ stack *stack.Stack
+}
// Number returns the udp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
@@ -42,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new udp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw UDP endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid udp packet size.
@@ -64,134 +76,30 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
return h.SourcePort(), h.DestinationPort(), nil
}
-// HandleUnknownDestinationPacket handles packets targeted at this protocol but
-// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
- // Get the header then trim it from the view.
- h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok {
- // Malformed packet.
+// HandleUnknownDestinationPacket handles packets that are targeted at this
+// protocol but don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ hdr := header.UDP(pkt.TransportHeader().View())
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
- return true
- }
- if int(header.UDP(h).Length()) > pkt.Data.Size() {
- // Malformed packet.
- r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
- return true
- }
- // TODO(b/129426613): only send an ICMP message if UDP checksum is valid.
-
- // Only send ICMP error if the address is not a multicast/broadcast
- // v4/v6 address or the source is not the unspecified address.
- //
- // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
- return true
+ return stack.UnknownDestinationPacketMalformed
}
- // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
- // Unreachable messages with code:
- //
- // 2 (Protocol Unreachable), when the designated transport protocol
- // is not supported; or
- //
- // 3 (Port Unreachable), when the designated transport protocol
- // (e.g., UDP) is unable to demultiplex the datagram but has no
- // protocol mechanism to inform the sender.
- switch len(id.LocalAddress) {
- case header.IPv4AddressSize:
- if !r.Stack().AllowICMPMessage() {
- r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
- return true
- }
- // As per RFC 1812 Section 4.3.2.3
- //
- // ICMP datagram SHOULD contain as much of the original
- // datagram as possible without the length of the ICMP
- // datagram exceeding 576 bytes
- //
- // NOTE: The above RFC referenced is different from the original
- // recommendation in RFC 1122 where it mentioned that at least 8
- // bytes of the payload must be included. Today linux and other
- // systems implement the] RFC1812 definition and not the original
- // RFC 1122 requirement.
- mtu := int(r.MTU())
- if mtu > header.IPv4MinimumProcessableDatagramSize {
- mtu = header.IPv4MinimumProcessableDatagramSize
- }
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
- available := int(mtu) - headerLen
- payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
- if payloadLen > available {
- payloadLen = available
- }
-
- // The buffers used by pkt may be used elsewhere in the system.
- // For example, a raw or packet socket may use what UDP
- // considers an unreachable destination. Thus we deep copy pkt
- // to prevent multiple ownership and SR errors.
- newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...)
- payload := newNetHeader.ToVectorisedView()
- payload.Append(pkt.Data.ToView().ToVectorisedView())
- payload.CapLength(payloadLen)
-
- hdr := buffer.NewPrependable(headerLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4DstUnreachable)
- pkt.SetCode(header.ICMPv4PortUnreachable)
- pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload,
- })
-
- case header.IPv6AddressSize:
- if !r.Stack().AllowICMPMessage() {
- r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
- return true
- }
-
- // As per RFC 4443 section 2.4
- //
- // (c) Every ICMPv6 error message (type < 128) MUST include
- // as much of the IPv6 offending (invoking) packet (the
- // packet that caused the error) as possible without making
- // the error message packet exceed the minimum IPv6 MTU
- // [IPv6].
- mtu := int(r.MTU())
- if mtu > header.IPv6MinimumMTU {
- mtu = header.IPv6MinimumMTU
- }
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
- available := int(mtu) - headerLen
- payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
- if payloadLen > available {
- payloadLen = available
- }
- payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader})
- payload.Append(pkt.Data)
- payload.CapLength(payloadLen)
-
- hdr := buffer.NewPrependable(headerLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize))
- pkt.SetType(header.ICMPv6DstUnreachable)
- pkt.SetCode(header.ICMPv6PortUnreachable)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload,
- })
+ if !verifyChecksum(r, hdr, pkt) {
+ r.Stack().Stats().UDP.ChecksumErrors.Increment()
+ return stack.UnknownDestinationPacketMalformed
}
- return true
+
+ return stack.UnknownDestinationPacketUnhandled
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
+func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
+func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -201,7 +109,12 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ return parse.UDP(pkt)
+}
+
// NewProtocol returns a UDP transport protocol.
-func NewProtocol() stack.TransportProtocol {
- return &protocol{}
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 8acaa607a..fb7738dda 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -83,16 +83,18 @@ type header4Tuple struct {
type testFlow int
const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ reverseMulticast4 // V4 multicast src. Must fail.
+ reverseMulticast6 // V6 multicast src. Must fail.
)
func (flow testFlow) String() string {
@@ -117,6 +119,10 @@ func (flow testFlow) String() string {
return "broadcast"
case broadcastIn6:
return "broadcastIn6"
+ case reverseMulticast4:
+ return "reverseMulticast4"
+ case reverseMulticast6:
+ return "reverseMulticast6"
default:
return "unknown"
}
@@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
h.dstAddr.Addr = multicastV6Addr
}
}
+ if flow.isReverseMulticast() {
+ h.srcAddr.Addr = flow.getMcastAddr()
+ }
return h
}
@@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
// endpoint for this flow.
func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast:
+ case unicastV4, multicastV4, broadcast, reverseMulticast4:
return ipv4.ProtocolNumber
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool {
switch flow {
case unicastV6Only, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool {
switch flow {
case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool {
switch flow {
case broadcast, broadcastIn6:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool {
switch flow {
case unicastV4in6, multicastV4in6, broadcastIn6:
return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
}
}
+func (flow testFlow) isReverseMulticast() bool {
+ switch flow {
+ case reverseMulticast4, reverseMulticast6:
+ return true
+ default:
+ return false
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -276,8 +294,8 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
return newDualTestContextWithOptions(t, mtu, stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
}
@@ -292,15 +310,15 @@ func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Optio
wep = sniffer.New(ep)
}
if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatalf("CreateNIC failed: %s", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -370,8 +388,12 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want {
+ c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want)
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
h := flow.header4Tuple(outgoing)
checkers = append(
@@ -385,23 +407,44 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw
}
// injectPacket creates a packet of the given flow and with the given payload,
-// and injects it into the link endpoint.
-func (c *testContext) injectPacket(flow testFlow, payload []byte) {
+// and injects it into the link endpoint. If badChecksum is true, the packet has
+// a bad checksum in the UDP header.
+func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) {
c.t.Helper()
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ if badChecksum {
+ // Invalidate the UDP header checksum field, taking care to avoid
+ // overflow to zero, which would disable checksum validation.
+ for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
+ u.SetChecksum(u.Checksum() + 1)
+ if u.Checksum() != 0 {
+ break
+ }
+ }
+ }
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
} else {
- c.injectV6Packet(payload, &h, true /* valid */)
+ buf := c.buildV6Packet(payload, &h)
+ if badChecksum {
+ // Invalidate the UDP header checksum field (Unlike IPv4, zero is
+ // a valid checksum value for IPv6 so no need to avoid it).
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+ }
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
}
}
-// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV6Packet creates a V6 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -420,16 +463,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
- l := uint16(header.UDPMinimumSize + len(payload))
- if !valid {
- // Change the UDP payload length to corrupt the header
- // as requested by the caller.
- l++
- }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: l,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
@@ -439,19 +476,12 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
- })
+ return buf
}
-// injectV4Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV4Packet creates a V4 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -485,13 +515,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
-
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
- })
+ return buf
}
func newPayload() []byte {
@@ -508,18 +532,18 @@ func newMinPayload(minSize int) []byte {
func TestBindToDeviceOption(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}})
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
opts := stack.NICOptions{Name: "my_device"}
if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
- t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
+ t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err)
}
// nicIDPtr is used instead of taking the address of NICID literals, which is
@@ -543,16 +567,15 @@ func TestBindToDeviceOption(t *testing.T) {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
+ if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
bindToDevice := tcpip.BindToDeviceOption(88888)
if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt got %v, want %v", err, nil)
- }
- if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %d, want %d", got, want)
+ t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
+ } else if bindToDevice != testAction.getBindToDevice {
+ t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
}
})
}
@@ -566,7 +589,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
c.t.Helper()
payload := newPayload()
- c.injectPacket(flow, payload)
+ c.injectPacket(flow, payload, false)
// Try to receive the data.
we, ch := waiter.NewChannelEntry(nil)
@@ -608,12 +631,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Check the peer address.
h := flow.header4Tuple(incoming)
if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr)
+ c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr)
}
// Check the payload.
if !bytes.Equal(payload, v) {
- c.t.Fatalf("bad payload: got %x, want %x", v, payload)
+ c.t.Fatalf("got payload = %x, want = %x", v, payload)
}
// Run any checkers against the ControlMessages.
@@ -647,7 +670,7 @@ func TestBindEphemeralPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}
@@ -658,40 +681,40 @@ func TestBindReservedPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
addr, err := c.ep.GetLocalAddress()
if err != nil {
- t.Fatalf("GetLocalAddress failed: %v", err)
+ t.Fatalf("GetLocalAddress failed: %s", err)
}
// We can't bind the address reserved by the connected endpoint above.
{
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+ t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want)
}
}
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// We can't bind ipv4-any on the port reserved by the connected endpoint
// above, since the endpoint is dual-stack.
if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+ t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want)
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
@@ -701,11 +724,11 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
}
@@ -718,7 +741,7 @@ func TestV4ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -733,7 +756,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
// Bind to v4 mapped wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -748,7 +771,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -763,7 +786,7 @@ func TestV6ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -784,8 +807,8 @@ func TestV4ReadSelfSource(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
HandleLocal: tt.handleLocal,
})
defer c.cleanup()
@@ -800,14 +823,17 @@ func TestV4ReadSelfSource(t *testing.T) {
h := unicastV4.header4Tuple(incoming)
h.srcAddr = h.dstAddr
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
}
if _, _, err := c.ep.Read(nil); err != tt.wantErr {
- t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr)
+ t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr)
}
})
}
@@ -821,7 +847,7 @@ func TestV4ReadOnV4(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -848,8 +874,8 @@ func TestReadOnBoundToMulticast(t *testing.T) {
// Join multicast group.
ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatal("SetSockOpt failed:", err)
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
}
// Check that we receive multicast packets but not unicast or broadcast
@@ -884,6 +910,24 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
}
}
+// TestReadFromMulticast checks that an endpoint will NOT receive a packet
+// that was sent with multicast SOURCE address.
+func TestReadFromMulticast(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ testFailingRead(c, flow, false /* expectReadError */)
+ })
+ }
+}
+
// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
// and receive broadcast and unicast data.
func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
@@ -959,7 +1003,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ c.t.Fatalf("Write failed: %s", err)
}
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
@@ -1009,7 +1053,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
p := testDualWrite(c)
@@ -1026,7 +1070,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV6)
@@ -1047,7 +1091,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV4in6)
@@ -1074,7 +1118,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Write to v6 address.
@@ -1089,7 +1133,7 @@ func TestV6WriteOnConnected(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV6)
@@ -1103,7 +1147,7 @@ func TestV4WriteOnConnected(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV4)
@@ -1238,7 +1282,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testRead(c, unicastV4)
@@ -1249,6 +1293,105 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
}
}
+func TestReadIPPacketInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ expectedLocalAddr tcpip.Address
+ expectedDestAddr tcpip.Address
+ }{
+ {
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ expectedLocalAddr: stackAddr,
+ expectedDestAddr: stackAddr,
+ },
+ {
+ name: "IPv4 multicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: multicastV4,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: multicastAddr,
+ expectedDestAddr: multicastAddr,
+ },
+ {
+ name: "IPv4 broadcast",
+ proto: header.IPv4ProtocolNumber,
+ flow: broadcast,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: broadcastAddr,
+ expectedDestAddr: broadcastAddr,
+ },
+ {
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ expectedLocalAddr: stackV6Addr,
+ expectedDestAddr: stackV6Addr,
+ },
+ {
+ name: "IPv6 multicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: multicastV6,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: multicastV6Addr,
+ expectedDestAddr: multicastV6Addr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(test.proto)
+
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ if test.flow.isMulticast() {
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
+ }
+ }
+
+ if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil {
+ t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err)
+ }
+
+ testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: 1,
+ LocalAddr: test.expectedLocalAddr,
+ DestinationAddr: test.expectedDestAddr,
+ }))
+
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
+ t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
+ }
+ })
+ }
+}
+
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1263,6 +1406,56 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
+func TestNoChecksum(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Disable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ // This option is effective on IPv4 only.
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
+
+ // Enable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
+ })
+ }
+}
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.NetworkLinkEndpoint
+}
+
+func (*testInterface) ID() tcpip.NICID {
+ return 0
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1280,19 +1473,19 @@ func TestTTL(t *testing.T) {
if flow.isMulticast() {
wantTTL = multicastTTL
} else {
- var p stack.NetworkProtocol
+ var p stack.NetworkProtocolFactory
+ var n tcpip.NetworkProtocolNumber
if flow.isV4() {
- p = ipv4.NewProtocol()
+ p = ipv4.NewProtocol
+ n = ipv4.ProtocolNumber
} else {
- p = ipv6.NewProtocol()
- }
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- }))
- if err != nil {
- t.Fatal(err)
+ p = ipv6.NewProtocol
+ n = ipv6.ProtocolNumber
}
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{p},
+ })
+ ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil)
wantTTL = ep.DefaultTTL()
ep.Close()
}
@@ -1316,21 +1509,6 @@ func TestSetTTL(t *testing.T) {
c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
- var p stack.NetworkProtocol
- if flow.isV4() {
- p = ipv4.NewProtocol()
- } else {
- p = ipv6.NewProtocol()
- }
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- }))
- if err != nil {
- t.Fatal(err)
- }
- ep.Close()
-
testWrite(c, flow, checker.TTL(wantTTL))
})
}
@@ -1353,7 +1531,7 @@ func TestSetTOS(t *testing.T) {
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
+ c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
}
if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
@@ -1510,23 +1688,21 @@ func TestMulticastInterfaceOption(t *testing.T) {
Port: stackPort,
}
if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
}
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
}
// Verify multicast interface addr and NIC were set correctly.
// Note that NIC must be 1 since this is our outgoing interface.
- ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
var ifoptGot tcpip.MulticastInterfaceOption
if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
- }
- if ifoptGot != ifoptWant {
- c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
+ c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err)
+ } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant {
+ c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant)
}
})
}
@@ -1550,21 +1726,33 @@ func TestV4UnknownDestination(t *testing.T) {
// so that the final generated IPv4 packet is larger than
// header.IPv4MinimumProcessableDatagramSize.
largePayload bool
+ // badChecksum if true, will set an invalid checksum in the
+ // header.
+ badChecksum bool
}{
- {unicastV4, true, false},
- {unicastV4, true, true},
- {multicastV4, false, false},
- {multicastV4, false, true},
- {broadcast, false, false},
- {broadcast, false, true},
- }
+ {unicastV4, true, false, false},
+ {unicastV4, true, true, false},
+ {unicastV4, false, false, true},
+ {unicastV4, false, true, true},
+ {multicastV4, false, false, false},
+ {multicastV4, false, true, false},
+ {broadcast, false, false, false},
+ {broadcast, false, true, false},
+ }
+ checksumErrors := uint64(0)
for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
payload := newPayload()
if tc.largePayload {
payload = newMinPayload(576)
}
- c.injectPacket(tc.flow, payload)
+ c.injectPacket(tc.flow, payload, tc.badChecksum)
+ if tc.badChecksum {
+ checksumErrors++
+ if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
+ t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ }
if !tc.icmpRequired {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
@@ -1583,9 +1771,8 @@ func TestV4UnknownDestination(t *testing.T) {
return
}
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
}
@@ -1595,16 +1782,25 @@ func TestV4UnknownDestination(t *testing.T) {
checker.ICMPv4Type(header.ICMPv4DstUnreachable),
checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ // We need to compare the included data part of the UDP packet that is in
+ // the ICMP packet with the matching original data.
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
wantLen := len(payload)
if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ // To work out the data size we need to simulate what the sender would
+ // have done. The wanted size is the total available minus the sum of
+ // the headers in the UDP AND ICMP packets, given that we know the test
+ // had only a minimal IP header but the ICMP sender will have allowed
+ // for a maximally sized packet header.
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
}
- // In case of large payloads the IP packet may be truncated. Update
+ // In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+ // Add back the two headers within the payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
origDgram := header.UDP(payloadIPHeader.Payload())
if got, want := len(origDgram.Payload()), wantLen; got != want {
@@ -1630,19 +1826,31 @@ func TestV6UnknownDestination(t *testing.T) {
// largePayload if true will result in a payload large enough to
// create an IPv6 packet > header.IPv6MinimumMTU bytes.
largePayload bool
+ // badChecksum if true, will set an invalid checksum in the
+ // header.
+ badChecksum bool
}{
- {unicastV6, true, false},
- {unicastV6, true, true},
- {multicastV6, false, false},
- {multicastV6, false, true},
- }
+ {unicastV6, true, false, false},
+ {unicastV6, true, true, false},
+ {unicastV6, false, false, true},
+ {unicastV6, false, true, true},
+ {multicastV6, false, false, false},
+ {multicastV6, false, true, false},
+ }
+ checksumErrors := uint64(0)
for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
payload := newPayload()
if tc.largePayload {
payload = newMinPayload(1280)
}
- c.injectPacket(tc.flow, payload)
+ c.injectPacket(tc.flow, payload, tc.badChecksum)
+ if tc.badChecksum {
+ checksumErrors++
+ if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
+ t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ }
if !tc.icmpRequired {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
@@ -1661,9 +1869,8 @@ func TestV6UnknownDestination(t *testing.T) {
return
}
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
}
@@ -1695,7 +1902,7 @@ func TestV6UnknownDestination(t *testing.T) {
}
// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats get incremented.
+// global and endpoint stats are incremented.
func TestIncrementMalformedPacketsReceived(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1703,20 +1910,228 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
payload := newPayload()
- c.t.Helper()
h := unicastV6.header4Tuple(incoming)
- c.injectV6Packet(payload, &h, false /* !valid */)
+ buf := c.buildV6Packet(payload, &h)
- var want uint64 = 1
+ // Invalidate the UDP header length field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetLength(u.Length() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
}
if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+}
+
+// TestShortHeader verifies that when a packet with a too-short UDP header is
+// received, the malformed received global stat gets incremented.
+func TestShortHeader(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ h := unicastV6.header4Tuple(incoming)
+
+ // Allocate a buffer for an IPv6 and too-short UDP header.
+ const udpSize = header.UDPMinimumSize - 1
+ buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
+ })
+
+ // Initialize the UDP header.
+ udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
+ udpHdr.Encode(&header.UDPFields{
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
+ Length: header.UDPMinimumSize,
+ })
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+ // Copy all but the last byte of the UDP header into the packet.
+ copy(buf[header.IPv6MinimumSize:], udpHdr)
+
+ // Inject packet.
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
+ t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
+ }
+}
+
+// TestBadChecksumErrors verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestBadChecksumErrors(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV6} {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(flow.sockProto())
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ c.injectPacket(flow, payload, true /* badChecksum */)
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+ }
+}
+
+// TestPayloadModifiedV4 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be
+ // incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV6 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be
+ // incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV4 verifies if the checksum value is zero, global and
+// endpoint states are *not* incremented (UDP checksum is optional on IPv4).
+func TestChecksumZeroV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 0
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV6 verifies if the checksum value is zero, global and
+// endpoint states are incremented (UDP checksum is *not* optional on IPv6).
+func TestChecksumZeroV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
}
}
@@ -1730,15 +2145,15 @@ func TestShutdownRead(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingRead(c, unicastV6, true /* expectReadError */)
@@ -1761,11 +2176,11 @@ func TestShutdownWrite(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
@@ -1807,3 +2222,193 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn
c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
}
}
+
+func TestOutgoingSubnetBroadcast(t *testing.T) {
+ const nicID1 = 1
+
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+ ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 31,
+ }
+ ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
+ ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
+ ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 32,
+ }
+ ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
+ ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+ remNetAddr := tcpip.AddressWithPrefix{
+ Address: "\x64\x0a\x7b\x18",
+ PrefixLen: 24,
+ }
+ remNetSubnet := remNetAddr.Subnet()
+ remNetSubnetBcast := remNetSubnet.Broadcast()
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ requiresBroadcastOpt bool
+ }{
+ {
+ name: "IPv4 Broadcast to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4SubnetBcast,
+ requiresBroadcastOpt: true,
+ },
+ {
+ name: "IPv4 Broadcast to local /31 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix31,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet31,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet31Bcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to local /32 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix32,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet32,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet32Bcast,
+ requiresBroadcastOpt: false,
+ },
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 'Broadcast' to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv6Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv6SubnetBcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to remote subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: remNetSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ // TODO(gvisor.dev/issue/3938): Once we support marking a route as
+ // broadcast, this test should require the broadcast option to be set.
+ requiresBroadcastOpt: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ }
+
+ s.SetRouteTable(test.routes)
+
+ var netProto tcpip.NetworkProtocolNumber
+ switch l := len(test.remoteAddr); l {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ netProto = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
+ }
+ defer ep.Close()
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ to := tcpip.FullAddress{
+ Addr: test.remoteAddr,
+ Port: 80,
+ }
+ opts := tcpip.WriteOptions{To: &to}
+ expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled
+ if !test.requiresBroadcastOpt {
+ expectedErrWithoutBcastOpt = nil
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != nil {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+ })
+ }
+}