summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go68
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go12
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go213
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go19
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go199
-rw-r--r--pkg/tcpip/transport/tcp/BUILD19
-rw-r--r--pkg/tcpip/transport/tcp/accept.go166
-rw-r--r--pkg/tcpip/transport/tcp/connect.go130
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go150
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go302
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go100
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go139
-rw-r--r--pkg/tcpip/transport/tcp/rack.go82
-rw-r--r--pkg/tcpip/transport/tcp/rack_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go4
-rw-r--r--pkg/tcpip/transport/tcp/segment.go43
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go23
-rw-r--r--pkg/tcpip/transport/tcp/snd.go105
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go32
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go74
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go803
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go40
-rw-r--r--pkg/tcpip/transport/tcp/timer.go1
-rw-r--r--pkg/tcpip/transport/tcp/timer_test.go47
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go5
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go259
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go7
-rw-r--r--pkg/tcpip/transport/udp/protocol.go85
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go815
32 files changed, 2906 insertions, 1082 deletions
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 9ce625c17..7e5c79776 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index b1d820372..bd6f49eb8 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -110,7 +111,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -140,11 +141,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -348,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+ }
return nil
}
@@ -430,9 +430,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return tcpip.ErrInvalidEndpointState
}
- hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()),
+ })
+ pkt.Owner = owner
- icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
copy(icmpv4, data)
// Set the ident to the user-specified port. Sequence number should
// already be set by the user.
@@ -447,15 +450,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+ pkt.Data = data.ToVectorisedView()
+
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: data.ToVectorisedView(),
- TransportHeader: buffer.View(icmpv4),
- Owner: owner,
- })
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
}
func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
@@ -463,9 +463,11 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
return tcpip.ErrInvalidEndpointState
}
- hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()),
+ })
- icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
copy(icmpv6, data)
// Set the ident. Sequence number is provided by the user.
icmpv6.SetIdent(ident)
@@ -477,15 +479,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
dataVV := data.ToVectorisedView()
icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV))
+ pkt.Data = dataVV
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: dataVV,
- TransportHeader: buffer.View(icmpv6),
- })
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
}
// checkV4MappedLocked determines the effective network protocol and converts
@@ -511,6 +510,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
localPort := uint16(0)
switch e.state {
+ case stateInitial:
case stateBound, stateConnected:
localPort = e.ID.LocalPort
if e.BindNICID == 0 {
@@ -611,14 +611,14 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
@@ -743,19 +743,23 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
- if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply {
+ h := header.ICMPv4(pkt.TransportHeader().View())
+ // TODO(b/129292233): Determine if len(h) check is still needed after early
+ // parsing.
+ if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
- if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply {
+ h := header.ICMPv6(pkt.TransportHeader().View())
+ // TODO(b/129292233): Determine if len(h) check is still needed after early
+ // parsing.
+ if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
@@ -789,12 +793,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
},
}
- packet.data = pkt.Data
+ // ICMP socket's data includes ICMP header.
+ packet.data = pkt.TransportHeader().View().ToVectorisedView()
+ packet.data.Append(pkt.Data)
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
@@ -805,7 +811,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
}
// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 3c47692b2..74ef6541e 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
return true
}
@@ -124,6 +124,16 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ // TODO(gvisor.dev/issue/170): Implement parsing of ICMP.
+ //
+ // Right now, the Parse() method is tied to enabled protocols passed into
+ // stack.New. This works for UDP and TCP, but we handle ICMP traffic even
+ // when netstack users don't pass ICMP as a supported protocol.
+ return false
+}
+
// NewProtocol4 returns an ICMPv4 transport protocol.
func NewProtocol4() stack.TransportProtocol {
return &protocol{ProtocolNumber4}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 23158173d..1b03ad6bb 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,6 +25,8 @@
package packet
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -43,6 +45,9 @@ type packet struct {
timestampNS int64
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ // packetInfo holds additional information like the protocol
+ // of the packet etc.
+ packetInfo tcpip.LinkPacketInfo
}
// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
@@ -71,11 +76,17 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
+ boundNIC tcpip.NICID
+
+ // lastErrorMu protects lastError.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
}
// NewEndpoint returns a new packet endpoint.
@@ -92,6 +103,17 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
sndBufSize: 32 * 1024,
}
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ ep.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ ep.rcvBufSizeMax = rs.Default
+ }
+
if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
return nil, err
}
@@ -132,13 +154,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (stack.IPTables, error) {
- return ep.stack.IPTables(), nil
-}
-
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.PacketEndpoint.ReadPacket.
+func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -163,11 +180,20 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
*addr = packet.senderAddr
}
+ if info != nil {
+ *info = packet.packetInfo
+ }
+
return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
}
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return ep.ReadPacket(addr, nil)
+}
+
func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- // TODO(b/129292371): Implement.
+ // TODO(gvisor.dev/issue/173): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
}
@@ -220,12 +246,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound {
- return tcpip.ErrAlreadyBound
+ if ep.bound && ep.boundNIC == addr.NIC {
+ // If the NIC being bound is the same then just return success.
+ return nil
}
// Unregister endpoint with all the nics.
ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.bound = false
// Bind endpoint to receive packets from specific interface.
if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
@@ -233,6 +261,7 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
ep.bound = true
+ ep.boundNIC = addr.NIC
return nil
}
@@ -269,7 +298,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
@@ -279,11 +314,63 @@ func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := ep.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ ep.mu.Lock()
+ ep.sndBufSizeMax = v
+ ep.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := ep.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ ep.rcvMu.Lock()
+ ep.rcvBufSizeMax = v
+ ep.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (ep *endpoint) takeLastError() *tcpip.Error {
+ ep.lastErrorMu.Lock()
+ defer ep.lastErrorMu.Unlock()
+
+ err := ep.lastError
+ ep.lastError = nil
+ return err
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return ep.takeLastError()
+ }
return tcpip.ErrNotSupported
}
@@ -294,11 +381,36 @@ func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() {
+ p := ep.rcvList.Front()
+ v = p.data.Size()
+ }
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSizeMax
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -320,48 +432,73 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
// Push new packet into receive list and increment the buffer size.
var packet packet
- // TODO(b/129292371): Return network protocol.
- if len(pkt.LinkHeader) > 0 {
+ // TODO(gvisor.dev/issue/173): Return network protocol.
+ if !pkt.LinkHeader().View().IsEmpty() {
// Get info directly from the ethernet header.
- hdr := header.Ethernet(pkt.LinkHeader)
+ hdr := header.Ethernet(pkt.LinkHeader().View())
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(hdr.SourceAddress()),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(localAddr),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = pkt.Data
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip Link Header.
+ var combinedVV buffer.VectorisedView
+ if v := pkt.NetworkHeader().View(); !v.IsEmpty() {
+ combinedVV.AppendView(v)
+ }
+ if v := pkt.TransportHeader().View(); !v.IsEmpty() {
+ combinedVV.AppendView(v)
+ }
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
var linkHeader buffer.View
- if len(pkt.LinkHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if pkt.LinkHeader().View().IsEmpty() {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
+ combinedVV := linkHeader.ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
} else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
}
- combinedVV := linkHeader.ToVectorisedView()
- combinedVV.Append(pkt.Data)
- packet.data = combinedVV
}
- packet.timestampNS = ep.stack.NowNanoseconds()
+ packet.timestampNS = ep.stack.Clock().NowNanoseconds()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 9b88f17e4..e2fa96d17 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() {
panic(*err)
}
}
+
+// saveLastError is invoked by stateify.
+func (ep *endpoint) saveLastError() string {
+ if ep.lastError == nil {
+ return ""
+ }
+
+ return ep.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (ep *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ ep.lastError = tcpip.StringToError(s)
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index eee754a5a..edc2b5b61 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,8 @@
package raw
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -61,21 +63,23 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
+ hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
rcvList rawPacketList
- rcvBufSizeMax int `state:".(int)"`
rcvBufSize int
+ rcvBufSizeMax int `state:".(int)"`
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- connected bool
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ connected bool
+ bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
@@ -91,7 +95,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber {
+ if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
@@ -103,8 +107,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
},
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
associated: associated,
+ hdrIncluded: !associated,
+ }
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
}
// Unassociated endpoints are write-only and users call Write() with IP
@@ -166,17 +182,8 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if !e.associated {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
- }
-
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -206,6 +213,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // We can create, but not write to, unassociated IPv6 endpoints.
+ if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -249,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if !e.associated {
+ if e.hdrIncluded {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -310,12 +322,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrNoRoute
}
- // We don't support IPv6 yet, so this has to be an IPv4 address.
- if len(opts.To.Addr) != header.IPv4AddressSize {
- e.mu.RUnlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
// Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
@@ -345,28 +351,26 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- switch e.NetProto {
- case header.IPv4ProtocolNumber:
- if !e.associated {
- if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- }); err != nil {
- return 0, nil, err
- }
- break
+ if e.hdrIncluded {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
+ if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
+ return 0, nil, err
}
-
- hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- Owner: e.owner,
- }); err != nil {
+ } else {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(route.MaxHeaderLength()),
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
+ pkt.Owner = e.owner
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: e.TransProto,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, pkt); err != nil {
return 0, nil, err
}
-
- default:
- return 0, nil, tcpip.ErrUnknownProtocol
}
return int64(len(payloadBytes)), nil, nil
@@ -391,11 +395,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // We don't support IPv6 yet.
- if len(addr.Addr) != header.IPv4AddressSize {
- return tcpip.ErrInvalidEndpointState
- }
-
nic := addr.NIC
if e.bound {
if e.BindNICID == 0 {
@@ -461,14 +460,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // Callers must provide an IPv4 address or no network address (for
- // binding to a NIC, but not an address).
- if len(addr.Addr) != 0 && len(addr.Addr) != 4 {
- return tcpip.ErrInvalidEndpointState
- }
-
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -518,17 +511,69 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ e.hdrIncluded = v
+ e.mu.Unlock()
+ return nil
+ }
return tcpip.ErrUnknownProtocolOption
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ e.rcvMu.Lock()
+ e.rcvBufSizeMax = v
+ e.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -548,6 +593,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.KeepaliveEnabledOption:
return false, nil
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ v := e.hdrIncluded
+ e.mu.Unlock()
+ return v, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -568,7 +619,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -584,11 +635,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
- // Drop the packet if our buffer is currently full.
- if e.rcvClosed {
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
@@ -632,15 +690,26 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) {
},
}
- networkHeader := append(buffer.View(nil), pkt.NetworkHeader...)
- combinedVV := networkHeader.ToVectorisedView()
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ var combinedVV buffer.VectorisedView
+ if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(network)+len(transport))
+ headers = append(headers, network...)
+ headers = append(headers, transport...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
+ }
combinedVV.Append(pkt.Data)
packet.data = combinedVV
- packet.timestampNS = e.stack.NowNanoseconds()
+ packet.timestampNS = e.stack.Clock().NowNanoseconds()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
-
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index f38eb6833..234fb95ce 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -40,6 +40,8 @@ go_library(
"endpoint_state.go",
"forwarder.go",
"protocol.go",
+ "rack.go",
+ "rack_state.go",
"rcv.go",
"rcv_state.go",
"reno.go",
@@ -49,6 +51,7 @@ go_library(
"segment_heap.go",
"segment_queue.go",
"segment_state.go",
+ "segment_unsafe.go",
"snd.go",
"snd_state.go",
"tcp_endpoint_list.go",
@@ -76,20 +79,18 @@ go_library(
)
go_test(
- name = "tcp_test",
+ name = "tcp_x_test",
size = "medium",
srcs = [
"dual_stack_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
+ "tcp_rack_test.go",
"tcp_sack_test.go",
"tcp_test.go",
"tcp_timestamp_test.go",
],
- # FIXME(b/68809571)
- tags = [
- "flaky",
- ],
+ shard_count = 10,
deps = [
":tcp",
"//pkg/sync",
@@ -119,3 +120,11 @@ go_test(
"//pkg/tcpip/seqnum",
],
)
+
+go_test(
+ name = "tcp_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ library = ":tcp",
+ deps = ["//pkg/sleep"],
+)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index e6a23c978..913ea6535 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -198,9 +198,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
}
// createConnectingEndpoint creates a new endpoint in a connecting state, with
-// the connection parameters given by the arguments. The endpoint is returned
-// with n.mu held.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
+// the connection parameters given by the arguments.
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -221,32 +220,12 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
- // Create sender and receiver.
- //
- // The receiver at least temporarily has a zero receive window scale,
- // but the caller may change it (before starting the protocol loop).
- n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
- n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize()))
// Bootstrap the auto tuning algorithm. Starting at zero will result in
// a large step function on the first window adjustment causing the
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
- // Lock the endpoint before registering to ensure that no out of
- // band changes are possible due to incoming packets etc till
- // the endpoint is done initializing.
- n.mu.Lock()
-
- // Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
- n.mu.Unlock()
- n.Close()
- return nil, err
- }
-
- n.isRegistered = true
-
- return n, nil
+ return n
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
@@ -257,10 +236,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
- if err != nil {
- return nil, err
- }
+ ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+
+ // Lock the endpoint before registering to ensure that no out of
+ // band changes are possible due to incoming packets etc till
+ // the endpoint is done initializing.
+ ep.mu.Lock()
ep.owner = owner
// listenEP is nil when listenContext is used by tcp.Forwarder.
@@ -268,18 +249,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if l.listenEP != nil {
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
+
l.listenEP.mu.Unlock()
// Ensure we release any registrations done by the newly
// created endpoint.
ep.mu.Unlock()
ep.Close()
- // Wake up any waiters. This is strictly not required normally
- // as a socket that was never accepted can't really have any
- // registered waiters except when stack.Wait() is called which
- // waits for all registered endpoints to stop and expects an
- // EventHUp.
- ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
@@ -288,21 +264,44 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// to the newly created endpoint.
l.listenEP.propagateInheritableOptionsLocked(ep)
+ if !ep.reserveTupleLocked() {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ l.listenEP.mu.Unlock()
+ }
+
+ return nil, tcpip.ErrConnectionAborted
+ }
+
deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
+ // Register new endpoint so that packets are routed to it.
+ if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
+
+ ep.drainClosingSegmentQueue()
+
+ return nil, err
+ }
+
+ ep.isRegistered = true
+
// Perform the 3-way handshake.
- h := newPassiveHandshake(ep, ep.rcv.rcvWnd, isn, irs, opts, deferAccept)
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.mu.Unlock()
ep.Close()
- // Wake up any waiters. This is strictly not required normally
- // as a socket that was never accepted can't really have any
- // registered waiters except when stack.Wait() is called which
- // waits for all registered endpoints to stop and expects an
- // EventHUp.
- ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ ep.notifyAborted()
if l.listenEP != nil {
l.removePendingEndpoint(ep)
@@ -378,6 +377,43 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
// Precondition: e.mu and n.mu must be held.
func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
+ n.portFlags = e.portFlags
+ n.boundBindToDevice = e.boundBindToDevice
+ n.boundPortFlags = e.boundPortFlags
+}
+
+// reserveTupleLocked reserves an accepted endpoint's tuple.
+//
+// Preconditions:
+// * propagateInheritableOptionsLocked has been called.
+// * e.mu is held.
+func (e *endpoint) reserveTupleLocked() bool {
+ dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort}
+ if !e.stack.ReserveTuple(
+ e.effectiveNetProtos,
+ ProtocolNumber,
+ e.ID.LocalAddress,
+ e.ID.LocalPort,
+ e.boundPortFlags,
+ e.boundBindToDevice,
+ dest,
+ ) {
+ return false
+ }
+
+ e.isPortReserved = true
+ e.boundDest = dest
+ return true
+}
+
+// notifyAborted wakes up any waiters on registered, but not accepted
+// endpoints.
+//
+// This is strictly not required normally as a socket that was never accepted
+// can't really have any registered waiters except when stack.Wait() is called
+// which waits for all registered endpoints to stop and expects an EventHUp.
+func (e *endpoint) notifyAborted() {
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
// handleSynSegment is called in its own goroutine once the listening endpoint
@@ -485,7 +521,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(timeStampOffset()),
+ TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
TSEcr: opts.TSVal,
MSS: mssForRoute(&s.route),
}
@@ -534,6 +570,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
+ iss := s.ackNumber - 1
+ irs := s.sequenceNumber - 1
+
// Since SYN cookies are in use this is potentially an ACK to a
// SYN-ACK we sent but don't have a half open connection state
// as cookies are being used to protect against a potential SYN
@@ -544,7 +583,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// when under a potential syn flood attack.
//
// Validate the cookie.
- data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
+ data, ok := ctx.isCookieValid(s.id, iss, irs)
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -569,16 +608,34 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{})
- if err != nil {
+ n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+
+ n.mu.Lock()
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ e.propagateInheritableOptionsLocked(n)
+
+ if !n.reserveTupleLocked() {
+ n.mu.Unlock()
+ n.Close()
+
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
- // Propagate any inheritable options from the listening endpoint
- // to the newly created endpoint.
- e.propagateInheritableOptionsLocked(n)
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+ n.mu.Unlock()
+ n.Close()
+
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ n.isRegistered = true
// clear the tsOffset for the newly created
// endpoint as the Timestamp was already
@@ -587,10 +644,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.tsOffset = 0
// Switch state to connected.
- // We do not use transitionToStateEstablishedLocked here as there is
- // no handshake state available when doing a SYN cookie based accept.
n.isConnectNotified = true
- n.setEndpointState(StateEstablished)
+ n.transitionToStateEstablishedLocked(&handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ })
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index a7e088d4e..290172ac9 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
}
// Wait for notification.
@@ -509,9 +512,7 @@ func (h *handshake) execute() *tcpip.Error {
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
- rt := time.AfterFunc(timeOut, func() {
- resendWaker.Assert()
- })
+ rt := time.AfterFunc(timeOut, resendWaker.Assert)
defer rt.Stop()
// Set up the wakers.
@@ -618,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
@@ -742,11 +746,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV
func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
optLen := len(tf.opts)
- hdr := &pkt.Header
- packetSize := pkt.Data.Size()
- // Initialize the header.
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
- pkt.TransportHeader = buffer.View(tcp)
+ tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen))
tcp.Encode(&header.TCPFields{
SrcPort: tf.id.LocalPort,
DstPort: tf.id.RemotePort,
@@ -758,8 +758,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
})
copy(tcp[header.TCPMinimumSize:], tf.opts)
- length := uint16(hdr.UsedLength() + packetSize)
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size()))
// Only calculate the checksum if offloading isn't supported.
if gso != nil && gso.NeedsCsum {
// This is called CHECKSUM_PARTIAL in the Linux kernel. We
@@ -797,17 +796,18 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
packetSize = size
}
size -= packetSize
- var pkt stack.PacketBuffer
- pkt.Header = buffer.NewPrependable(hdrSize)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: hdrSize,
+ })
pkt.Hash = tf.txHash
pkt.Owner = owner
pkt.EgressRoute = r
pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = r.NetworkProtocolNumber()
data.ReadToVV(&pkt.Data, packetSize)
- buildTCPHdr(r, tf, &pkt, gso)
+ buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
- pkts.PushBack(&pkt)
+ pkts.PushBack(pkt)
}
if tf.ttl == 0 {
@@ -833,13 +833,13 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac
return sendTCPBatch(r, tf, data, gso, owner)
}
- pkt := stack.PacketBuffer{
- Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
- Data: data,
- Hash: tf.txHash,
- Owner: owner,
- }
- buildTCPHdr(r, tf, &pkt, gso)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen,
+ Data: data,
+ })
+ pkt.Hash = tf.txHash
+ pkt.Owner = owner
+ buildTCPHdr(r, tf, pkt, gso)
if tf.ttl == 0 {
tf.ttl = r.DefaultTTL()
@@ -995,24 +995,22 @@ func (e *endpoint) completeWorkerLocked() {
// transitionToStateEstablisedLocked transitions a given endpoint
// to an established state using the handshake parameters provided.
-// It also initializes sender/receiver if required.
+// It also initializes sender/receiver.
func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
- if e.snd == nil {
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
- }
- if e.rcv == nil {
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
- e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
- // Bootstrap the auto tuning algorithm. Starting at zero will
- // result in a really large receive window after the first auto
- // tuning adjustment.
- e.rcvAutoParams.prevCopied = int(h.rcvWnd)
- e.rcvListMu.Unlock()
- }
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd)
+ e.rcvListMu.Unlock()
+
e.setEndpointState(StateEstablished)
}
@@ -1022,14 +1020,19 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// delivered to this endpoint from the demuxer when the endpoint
// is transitioned to StateClose.
func (e *endpoint) transitionToStateCloseLocked() {
- if e.EndpointState() == StateClose {
+ s := e.EndpointState()
+ if s == StateClose {
return
}
+
+ if s.connected() {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
+ }
+
// Mark the endpoint as fully closed for reads/writes.
e.cleanupLocked()
e.setEndpointState(StateClose)
- e.stack.Stats().TCP.CurrentConnected.Decrement()
- e.stack.Stats().TCP.EstablishedClosed.Increment()
}
// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
@@ -1052,8 +1055,8 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
panic("current endpoint not removed from demuxer, enqueing segments to itself")
}
- if ep.(*endpoint).enqueueSegment(s) {
- ep.(*endpoint).newSegmentWaker.Assert()
+ if ep := ep.(*endpoint); ep.enqueueSegment(s) {
+ ep.newSegmentWaker.Assert()
}
}
@@ -1122,7 +1125,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
- if e.EndpointState() == StateClose || e.EndpointState() == StateError {
+ if e.EndpointState().closed() {
return nil
}
s := e.segmentQueue.dequeue()
@@ -1159,13 +1162,18 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
return nil
}
-// handleSegment handles a given segment and notifies the worker goroutine if
-// if the connection should be terminated.
-func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
- // Invoke the tcp probe if installed.
+func (e *endpoint) probeSegment() {
if e.probe != nil {
e.probe(e.completeState())
}
+}
+
+// handleSegment handles a given segment and notifies the worker goroutine if
+// if the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed. The tcp probe function will update
+ // the TCPEndpointState after the segment is processed.
+ defer e.probeSegment()
if s.flagIsSet(header.TCPFlagRst) {
if ok, err := e.handleReset(s); !ok {
@@ -1347,6 +1355,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.setEndpointState(StateError)
e.HardError = err
+ e.workerCleanup = true
// Lock released below.
epilogue()
return err
@@ -1441,9 +1450,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
- closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
- closeWaker.Assert()
- })
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
}
@@ -1461,7 +1468,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
return err
}
}
- if e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ if !e.EndpointState().closed() {
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
@@ -1517,6 +1524,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
cleanupOnError := func(err *tcpip.Error) {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
e.workerCleanup = true
if err != nil {
e.resetConnectionLocked(err)
@@ -1526,7 +1534,12 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
loop:
- for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ for {
+ switch e.EndpointState() {
+ case StateTimeWait, StateClose, StateError:
+ break loop
+ }
+
e.mu.Unlock()
v, _ := s.Fetch(true)
e.mu.Lock()
@@ -1569,11 +1582,14 @@ loop:
reuseTW = e.doTimeWait()
}
- // Mark endpoint as closed.
- if e.EndpointState() != StateError {
- e.transitionToStateCloseLocked()
+ // Handle any StateError transition from StateTimeWait.
+ if e.EndpointState() == StateError {
+ cleanupOnError(nil)
+ return nil
}
+ e.transitionToStateCloseLocked()
+
// Lock released below.
epilogue()
@@ -1686,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
}
case notification:
n := e.fetchNotifications()
- if n&notifyClose != 0 || n&notifyAbort != 0 {
+ if n&notifyAbort != 0 {
return nil
}
if n&notifyDrain != 0 {
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 6062ca916..98aecab9e 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -15,6 +15,8 @@
package tcp
import (
+ "encoding/binary"
+
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
@@ -66,89 +68,68 @@ func (q *epQueue) empty() bool {
// processor is responsible for processing packets queued to a tcp endpoint.
type processor struct {
epQ epQueue
+ sleeper sleep.Sleeper
newEndpointWaker sleep.Waker
closeWaker sleep.Waker
- id int
- wg sync.WaitGroup
-}
-
-func newProcessor(id int) *processor {
- p := &processor{
- id: id,
- }
- p.wg.Add(1)
- go p.handleSegments()
- return p
}
func (p *processor) close() {
p.closeWaker.Assert()
}
-func (p *processor) wait() {
- p.wg.Wait()
-}
-
func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep)
p.newEndpointWaker.Assert()
}
-func (p *processor) handleSegments() {
- const newEndpointWaker = 1
- const closeWaker = 2
- s := sleep.Sleeper{}
- s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
- s.AddWaker(&p.closeWaker, closeWaker)
- defer s.Done()
+const (
+ newEndpointWaker = 1
+ closeWaker = 2
+)
+
+func (p *processor) start(wg *sync.WaitGroup) {
+ defer wg.Done()
+ defer p.sleeper.Done()
+
for {
- id, ok := s.Fetch(true)
- if ok && id == closeWaker {
- p.wg.Done()
- return
+ if id, _ := p.sleeper.Fetch(true); id == closeWaker {
+ break
}
- for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
+ for {
+ ep := p.epQ.dequeue()
+ if ep == nil {
+ break
+ }
if ep.segmentQueue.empty() {
continue
}
- // If socket has transitioned out of connected state
- // then just let the worker handle the packet.
+ // If socket has transitioned out of connected state then just let the
+ // worker handle the packet.
//
- // NOTE: We read this outside of e.mu lock which means
- // that by the time we get to handleSegments the
- // endpoint may not be in ESTABLISHED. But this should
- // be fine as all normal shutdown states are handled by
- // handleSegments and if the endpoint moves to a
- // CLOSED/ERROR state then handleSegments is a noop.
- if ep.EndpointState() != StateEstablished {
- ep.newSegmentWaker.Assert()
- continue
- }
-
- if !ep.mu.TryLock() {
- ep.newSegmentWaker.Assert()
- continue
- }
- // If the endpoint is in a connected state then we do
- // direct delivery to ensure low latency and avoid
- // scheduler interactions.
- if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
- // Send any active resets if required.
- if err != nil {
+ // NOTE: We read this outside of e.mu lock which means that by the time
+ // we get to handleSegments the endpoint may not be in ESTABLISHED. But
+ // this should be fine as all normal shutdown states are handled by
+ // handleSegments and if the endpoint moves to a CLOSED/ERROR state
+ // then handleSegments is a noop.
+ if ep.EndpointState() == StateEstablished && ep.mu.TryLock() {
+ // If the endpoint is in a connected state then we do direct delivery
+ // to ensure low latency and avoid scheduler interactions.
+ switch err := ep.handleSegments(true /* fastPath */); {
+ case err != nil:
+ // Send any active resets if required.
ep.resetConnectionLocked(err)
+ fallthrough
+ case ep.EndpointState() == StateClose:
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ case !ep.segmentQueue.empty():
+ p.epQ.enqueue(ep)
}
- ep.notifyProtocolGoroutine(notifyTickleWorker)
ep.mu.Unlock()
- continue
- }
-
- if !ep.segmentQueue.empty() {
- p.epQ.enqueue(ep)
+ } else {
+ ep.newSegmentWaker.Assert()
}
-
- ep.mu.Unlock()
}
}
}
@@ -159,34 +140,39 @@ func (p *processor) handleSegments() {
// hash of the endpoint id to ensure that delivery for the same endpoint happens
// in-order.
type dispatcher struct {
- processors []*processor
+ processors []processor
seed uint32
-}
-
-func newDispatcher(nProcessors int) *dispatcher {
- processors := []*processor{}
- for i := 0; i < nProcessors; i++ {
- processors = append(processors, newProcessor(i))
- }
- return &dispatcher{
- processors: processors,
- seed: generateRandUint32(),
+ wg sync.WaitGroup
+}
+
+func (d *dispatcher) init(nProcessors int) {
+ d.close()
+ d.wait()
+ d.processors = make([]processor, nProcessors)
+ d.seed = generateRandUint32()
+ for i := range d.processors {
+ p := &d.processors[i]
+ p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ p.sleeper.AddWaker(&p.closeWaker, closeWaker)
+ d.wg.Add(1)
+ // NB: sleeper-waker registration must happen synchronously to avoid races
+ // with `close`. It's possible to pull all this logic into `start`, but
+ // that results in a heap-allocated function literal.
+ go p.start(&d.wg)
}
}
func (d *dispatcher) close() {
- for _, p := range d.processors {
- p.close()
+ for i := range d.processors {
+ d.processors[i].close()
}
}
func (d *dispatcher) wait() {
- for _, p := range d.processors {
- p.wait()
- }
+ d.wg.Wait()
}
-func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
if !s.parse() {
@@ -231,20 +217,18 @@ func generateRandUint32() uint32 {
if _, err := rand.Read(b); err != nil {
panic(err)
}
- return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+ return binary.LittleEndian.Uint32(b)
}
func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
- payload := []byte{
- byte(id.LocalPort),
- byte(id.LocalPort >> 8),
- byte(id.RemotePort),
- byte(id.RemotePort >> 8)}
+ var payload [4]byte
+ binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.LittleEndian.PutUint16(payload[2:], id.RemotePort)
h := jenkins.Sum32(d.seed)
- h.Write(payload)
+ h.Write(payload[:])
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
- return d.processors[h.Sum32()%uint32(len(d.processors))]
+ return &d.processors[h.Sum32()%uint32(len(d.processors))]
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b5ba972f1..d08cfe0ff 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -63,7 +63,8 @@ const (
StateClosing
)
-// connected is the set of states where an endpoint is connected to a peer.
+// connected returns true when s is one of the states representing an
+// endpoint connected to a peer.
func (s EndpointState) connected() bool {
switch s {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
@@ -73,6 +74,40 @@ func (s EndpointState) connected() bool {
}
}
+// connecting returns true when s is one of the states representing a
+// connection in progress, but not yet fully established.
+func (s EndpointState) connecting() bool {
+ switch s {
+ case StateConnecting, StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// handshake returns true when s is one of the states representing an endpoint
+// in the middle of a TCP handshake.
+func (s EndpointState) handshake() bool {
+ switch s {
+ case StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// closed returns true when s is one of the states an endpoint transitions to
+// when closed or when it encounters an error. This is distinct from a newly
+// initialized endpoint that was never connected.
+func (s EndpointState) closed() bool {
+ switch s {
+ case StateClose, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
// String implements fmt.Stringer.String.
func (s EndpointState) String() string {
switch s {
@@ -361,7 +396,8 @@ type endpoint struct {
mu sync.Mutex `state:"nosave"`
ownedByUser uint32
- // state must be read/set using the EndpointState()/setEndpointState() methods.
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
state EndpointState `state:".(EndpointState)"`
// origEndpointState is only used during a restore phase to save the
@@ -370,8 +406,8 @@ type endpoint struct {
origEndpointState EndpointState `state:"nosave"`
isPortReserved bool `state:"manual"`
- isRegistered bool
- boundNICID tcpip.NICID `state:"manual"`
+ isRegistered bool `state:"manual"`
+ boundNICID tcpip.NICID
route stack.Route `state:"manual"`
ttl uint8
v6only bool
@@ -380,10 +416,14 @@ type endpoint struct {
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool
+ // portFlags stores the current values of port related flags.
+ portFlags ports.Flags
+
// Values used to reserve a port or register a transport endpoint
// (which ever happens first).
boundBindToDevice tcpip.NICID
boundPortFlags ports.Flags
+ boundDest tcpip.FullAddress
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -391,7 +431,7 @@ type endpoint struct {
// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
- effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
@@ -409,10 +449,11 @@ type endpoint struct {
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
- //
- // recentTS must be read/written atomically.
recentTS uint32
+ // recentTSTime is the unix time when we updated recentTS last.
+ recentTSTime time.Time `state:".(unixTime)"`
+
// tsOffset is a randomized offset added to the value of the
// TSVal field in the timestamp option.
tsOffset uint32
@@ -427,9 +468,6 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
- // reusePort is set to true if SO_REUSEPORT is enabled.
- reusePort bool
-
// bindToDevice is set to the NIC on which to bind or disabled if 0.
bindToDevice tcpip.NICID
@@ -449,7 +487,6 @@ type endpoint struct {
// The options below aren't implemented, but we remember the user
// settings because applications expect to be able to set/query these
// options.
- reuseAddr bool
// slowAck holds the negated state of quick ack. It is stubbed out and
// does nothing.
@@ -759,15 +796,15 @@ func (e *endpoint) EndpointState() EndpointState {
return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
}
-// setRecentTimestamp atomically sets the recentTS field to the
-// provided value.
+// setRecentTimestamp sets the recentTS field to the provided value.
func (e *endpoint) setRecentTimestamp(recentTS uint32) {
- atomic.StoreUint32(&e.recentTS, recentTS)
+ e.recentTS = recentTS
+ e.recentTSTime = time.Now()
}
-// recentTimestamp atomically reads and returns the value of the recentTS field.
+// recentTimestamp returns the value of the recentTS field.
func (e *endpoint) recentTimestamp() uint32 {
- return atomic.LoadUint32(&e.recentTS)
+ return e.recentTS
}
// keepalive is a synchronization wrapper used to appease stateify. See the
@@ -799,7 +836,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
rcvBufSize: DefaultReceiveBufferSize,
sndBufSize: DefaultSendBufferSize,
sndMTU: int(math.MaxInt32),
- reuseAddr: true,
keepalive: keepalive{
// Linux defaults.
idle: 2 * time.Hour,
@@ -867,7 +903,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
- case StateClose, StateError:
+ case StateClose, StateError, StateTimeWait:
// Ready for anything.
result = mask
@@ -986,14 +1022,15 @@ func (e *endpoint) closeNoShutdownLocked() {
// in Listen() when trying to register.
if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
e.isPortReserved = false
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
}
// Mark endpoint as closed.
@@ -1051,16 +1088,17 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
e.isPortReserved = false
}
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
e.route.Release()
e.stack.CompleteTransportEndpointCleanup(e)
@@ -1172,14 +1210,27 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+ err := e.lastError
+ e.lastError = nil
+ return err
}
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
+ defer e.UnlockUser()
+
+ // When in SYN-SENT state, let the caller block on the receive.
+ // An application can initiate a non-blocking connect and then block
+ // on a receive. It can expect to read any data after the handshake
+ // is complete. RFC793, section 3.9, p58.
+ if e.EndpointState() == StateSynSent {
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data. Also note that a RST being received
// would cause the state to become StateError so we should allow the
@@ -1189,7 +1240,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
- e.UnlockUser()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
@@ -1199,7 +1249,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
- e.UnlockUser()
if err == tcpip.ErrClosedForReceive {
e.stats.ReadErrors.ReadClosed.Increment()
@@ -1486,12 +1535,12 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
case tcpip.ReuseAddressOption:
e.LockUser()
- e.reuseAddr = v
+ e.portFlags.TupleOnly = v
e.UnlockUser()
case tcpip.ReusePortOption:
e.LockUser()
- e.reusePort = v
+ e.portFlags.LoadBalanced = v
e.UnlockUser()
case tcpip.V6OnlyOption:
@@ -1549,6 +1598,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.UnlockUser()
e.notifyProtocolGoroutine(notifyMSSChanged)
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if attempting to set this option to
+ // anything other than path MTU discovery disabled.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1722,15 +1778,8 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// Same as effectively disabling TCPLinger timeout.
v = 0
}
- var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption
- if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil {
- // We were unable to retrieve a stack config, just use
- // the DefaultTCPLingerTimeout.
- if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) {
- stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout)
- }
- }
- // Cap it to the stack wide TCPLinger timeout.
+ // Cap it to MaxTCPLingerTimeout.
+ stkTCPLingerTimeout := tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout)
if v > stkTCPLingerTimeout {
v = stkTCPLingerTimeout
}
@@ -1745,6 +1794,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.deferAccept = time.Duration(v)
e.UnlockUser()
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
default:
return nil
}
@@ -1795,14 +1847,14 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.ReuseAddressOption:
e.LockUser()
- v := e.reuseAddr
+ v := e.portFlags.TupleOnly
e.UnlockUser()
return v, nil
case tcpip.ReusePortOption:
e.LockUser()
- v := e.reusePort
+ v := e.portFlags.LoadBalanced
e.UnlockUser()
return v, nil
@@ -1819,6 +1871,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
+ case tcpip.MulticastLoopOption:
+ return true, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -1853,6 +1908,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
v := header.TCPDefaultMSS
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // Always return the path MTU discovery disabled setting since
+ // it's the only one supported.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1886,6 +1946,9 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.UnlockUser()
return v, nil
+ case tcpip.MulticastTTLOption:
+ return 1, nil
+
default:
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -1895,11 +1958,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- e.lastErrorMu.Lock()
- err := e.lastError
- e.lastError = nil
- e.lastErrorMu.Unlock()
- return err
+ return e.takeLastError()
case *tcpip.BindToDeviceOption:
e.LockUser()
@@ -1952,6 +2011,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = tcpip.TCPDeferAcceptOption(e.deferAccept)
e.UnlockUser()
+ case *tcpip.OriginalDestinationOption:
+ ipt := e.stack.IPTables()
+ addr, port, err := ipt.OriginalDst(e.ID)
+ if err != nil {
+ return err
+ }
+ *o = tcpip.OriginalDestinationOption{
+ Addr: addr,
+ Port: port,
+ }
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2049,8 +2119,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
defer r.Release()
- origID := e.ID
-
netProtos := []tcpip.NetworkProtocolNumber{netProto}
e.ID.LocalAddress = r.LocalAddress
e.ID.RemoteAddress = r.RemoteAddress
@@ -2058,7 +2126,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
if err != nil {
return err
}
@@ -2081,43 +2149,91 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
h.Write(portBuf)
portOffset := h.Sum32()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err))
+ }
+
+ reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal
+ if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly {
+ switch netProto {
+ case header.IPv4ProtocolNumber:
+ reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress)
+ case header.IPv6ProtocolNumber:
+ reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback
+ }
+ }
+
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- // reusePort is false below because connect cannot reuse a port even if
- // reusePort was set.
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) {
- return false, nil
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ if err != tcpip.ErrPortInUse || !reuse {
+ return false, nil
+ }
+ transEPID := e.ID
+ transEPID.LocalPort = p
+ // Check if an endpoint is registered with demuxer in TIME-WAIT and if
+ // we can reuse it. If we can't find a transport endpoint then we just
+ // skip using this port as it's possible that either an endpoint has
+ // bound the port but not registered with demuxer yet (no listen/connect
+ // done yet) or the reservation was freed between the check above and
+ // the FindTransportEndpoint below. But rather than retry the same port
+ // we just skip it and move on.
+ transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r)
+ if transEP == nil {
+ // ReservePort failed but there is no registered endpoint with
+ // demuxer. Which indicates there is at least some endpoint that has
+ // bound the port.
+ return false, nil
+ }
+
+ tcpEP := transEP.(*endpoint)
+ tcpEP.LockUser()
+ // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but
+ // less than 1 second has elapsed since its recentTS was updated then
+ // we cannot reuse the port.
+ if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second {
+ tcpEP.UnlockUser()
+ return false, nil
+ }
+ // Since the endpoint is in TIME-WAIT it should be safe to acquire its
+ // Lock while holding the lock for this endpoint as endpoints in
+ // TIME-WAIT do not acquire locks on other endpoints.
+ tcpEP.workerCleanup = false
+ tcpEP.cleanupLocked()
+ tcpEP.notifyProtocolGoroutine(notifyAbort)
+ tcpEP.UnlockUser()
+ // Now try and Reserve again if it fails then we skip.
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ return false, nil
+ }
}
id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
- case nil:
- // Port picking successful. Save the details of
- // the selected port.
- e.ID = id
- e.boundBindToDevice = e.bindToDevice
- return true, nil
- case tcpip.ErrPortInUse:
- return false, nil
- default:
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err == tcpip.ErrPortInUse {
+ return false, nil
+ }
return false, err
}
+
+ // Port picking successful. Save the details of
+ // the selected port.
+ e.ID = id
+ e.isPortReserved = true
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
+ e.boundDest = addr
+ return true, nil
}); err != nil {
return err
}
}
- // Remove the port reservation. This can happen when Bind is called
- // before Connect: in such a case we don't want to hold on to
- // reservations anymore.
- if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
- e.isPortReserved = false
- }
-
e.isRegistered = true
e.setEndpointState(StateConnecting)
e.route = r.Clone()
@@ -2296,7 +2412,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
return err
}
@@ -2388,16 +2504,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- flags := ports.Flags{
- LoadBalanced: e.reusePort,
- }
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
if err != nil {
return err
}
e.boundBindToDevice = e.bindToDevice
- e.boundPortFlags = flags
+ e.boundPortFlags = e.portFlags
e.isPortReserved = true
e.effectiveNetProtos = netProtos
e.ID.LocalPort = port
@@ -2405,7 +2518,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Any failures beyond this point must remove the port registration.
defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{})
e.isPortReserved = false
e.effectiveNetProtos = nil
e.ID.LocalPort = 0
@@ -2428,6 +2541,10 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
e.ID.LocalAddress = addr.Addr
}
+ if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ return err
+ }
+
// Mark endpoint as bound.
e.setEndpointState(StateBound)
@@ -2462,7 +2579,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}, nil
}
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
// land at the Dispatcher which then can either delivery using the
// worker go routine or directly do the invoke the tcp processing inline
@@ -2481,7 +2598,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
@@ -2492,6 +2609,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.sndBufMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
+
+ case stack.ControlNoRoute:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNoRoute
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
+
+ case stack.ControlNetworkUnreachable:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNetworkUnreachable
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
}
}
@@ -2611,15 +2740,14 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
// timestamp returns the timestamp value to be used in the TSVal field of the
// timestamp option for outgoing TCP segments for a given endpoint.
func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(e.tsOffset)
+ return tcpTimeStamp(time.Now(), e.tsOffset)
}
// tcpTimeStamp returns a timestamp offset by the provided offset. This is
// not inlined above as it's used when SYN cookies are in use and endpoint
// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(offset uint32) uint32 {
- now := time.Now()
- return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset
+func tcpTimeStamp(curTime time.Time, offset uint32) uint32 {
+ return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset
}
// timeStampOffset returns a randomized timestamp offset to be used when sending
@@ -2762,6 +2890,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
WEst: cubic.wEst,
}
}
+
+ rc := e.snd.rc
+ s.Sender.RACKState = stack.TCPRACKState{
+ XmitTime: rc.xmitTime,
+ EndSequence: rc.endSequence,
+ FACK: rc.fack,
+ RTT: rc.rtt,
+ }
return s
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index fc43c11e2..723e47ddc 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -49,11 +49,10 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
defer e.mu.Unlock()
- switch e.EndpointState() {
- case StateInitial, StateBound:
- // TODO(b/138137272): this enumeration duplicates
- // EndpointState.connected. remove it.
- case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ epState := e.EndpointState()
+ switch {
+ case epState == StateInitial || epState == StateBound:
+ case epState.connected() || epState.handshake():
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
@@ -69,15 +68,16 @@ func (e *endpoint) beforeSave() {
break
}
fallthrough
- case StateListen, StateConnecting:
+ case epState == StateListen || epState == StateConnecting:
e.drainSegmentLocked()
- if e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ // Refresh epState, since drainSegmentLocked may have changed it.
+ epState = e.EndpointState()
+ if !epState.closed() {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
- break
}
- case StateError, StateClose:
+ case epState.closed():
for e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
@@ -93,10 +93,6 @@ func (e *endpoint) beforeSave() {
if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
panic("endpoint still has waiters upon save")
}
-
- if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) {
- panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
- }
}
// saveAcceptedChan is invoked by stateify.
@@ -148,23 +144,23 @@ var connectingLoading sync.WaitGroup
// Bound endpoint loading happens last.
// loadState is invoked by stateify.
-func (e *endpoint) loadState(state EndpointState) {
+func (e *endpoint) loadState(epState EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
// For restore purposes we treat TimeWait like a connected endpoint.
- if state.connected() || state == StateTimeWait {
+ if epState.connected() || epState == StateTimeWait {
connectedLoading.Add(1)
}
- switch state {
- case StateListen:
+ switch {
+ case epState == StateListen:
listenLoading.Add(1)
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
connectingLoading.Add(1)
}
// Directly update the state here rather than using e.setEndpointState
// as the endpoint is still being loaded and the stack reference is not
// yet initialized.
- atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+ atomic.StoreUint32((*uint32)(&e.state), uint32(epState))
}
// afterLoad is invoked by stateify.
@@ -183,33 +179,40 @@ func (e *endpoint) afterLoad() {
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- state := e.origEndpointState
- switch state {
+ epState := e.origEndpointState
+ switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
}
- if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
}
}
}
bind := func() {
- if len(e.BindAddr) == 0 {
- e.BindAddr = e.ID.LocalAddress
+ addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort})
+ if err != nil {
+ panic("unable to parse BindAddr: " + err.String())
}
- addr := e.BindAddr
- port := e.ID.LocalPort
- if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil {
- panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err))
+ if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok {
+ panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
}
+ e.isPortReserved = true
+
+ // Mark endpoint as bound.
+ e.setEndpointState(StateBound)
}
- switch state {
- case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ switch {
+ case epState.connected():
bind()
if len(e.connectingAddress) == 0 {
e.connectingAddress = e.ID.RemoteAddress
@@ -232,13 +235,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
closed := e.closed
e.mu.Unlock()
e.notifyProtocolGoroutine(notifyTickleWorker)
- if state == StateFinWait2 && closed {
+ if epState == StateFinWait2 && closed {
// If the endpoint has been closed then make sure we notify so
// that the FIN_WAIT2 timer is started after a restore.
e.notifyProtocolGoroutine(notifyClose)
}
connectedLoading.Done()
- case StateListen:
+ case epState == StateListen:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -255,7 +258,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -267,7 +270,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectingLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateBound:
+ case epState == StateBound:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -276,27 +279,16 @@ func (e *endpoint) Resume(s *stack.Stack) {
bind()
tcpip.AsyncLoading.Done()
}()
- case StateClose:
- if e.isPortReserved {
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- connectingLoading.Wait()
- bind()
- e.setEndpointState(StateClose)
- tcpip.AsyncLoading.Done()
- }()
- }
+ case epState == StateClose:
+ e.isPortReserved = false
e.state = StateClose
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
- case StateError:
+ case epState == StateError:
e.state = StateError
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
-
}
// saveLastError is invoked by stateify.
@@ -317,6 +309,16 @@ func (e *endpoint) loadLastError(s string) {
e.lastError = tcpip.StringToError(s)
}
+// saveRecentTSTime is invoked by stateify.
+func (e *endpoint) saveRecentTSTime() unixTime {
+ return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()}
+}
+
+// loadRecentTSTime is invoked by stateify.
+func (e *endpoint) loadRecentTSTime(unix unixTime) {
+ e.recentTSTime = time.Unix(unix.second, unix.nano)
+}
+
// saveHardError is invoked by stateify.
func (e *EndpointInfo) saveHardError() string {
if e.HardError == nil {
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 704d01c64..070b634b4 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2a2a7ddeb..c5afa2680 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -61,6 +61,10 @@ const (
// FIN_WAIT_2 state before being marked closed.
DefaultTCPLingerTimeout = 60 * time.Second
+ // MaxTCPLingerTimeout is the maximum amount of time that sockets
+ // linger in FIN_WAIT_2 state before being marked closed.
+ MaxTCPLingerTimeout = 120 * time.Second
+
// DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
// in TIME_WAIT state before being marked closed.
DefaultTCPTimeWaitTimeout = 60 * time.Second
@@ -70,34 +74,55 @@ const (
DefaultSynRetries = 6
)
-// SACKEnabled option can be used to enable SACK support in the TCP
-// protocol. See: https://tools.ietf.org/html/rfc2018.
+const (
+ ccReno = "reno"
+ ccCubic = "cubic"
+)
+
+// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to
+// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
-// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol.
+// Recovery is used by stack.(*Stack).TransportProtocolOption to
+// set loss detection algorithm in TCP.
+type Recovery int32
+
+const (
+ // RACKLossDetection indicates RACK is used for loss detection and
+ // recovery.
+ RACKLossDetection Recovery = 1 << iota
+
+ // RACKStaticReoWnd indicates the reordering window should not be
+ // adjusted when DSACK is received.
+ RACKStaticReoWnd
+
+ // RACKNoDupTh indicates RACK should not consider the classic three
+ // duplicate acknowledgements rule to mark the segments as lost. This
+ // is used when reordering is not detected.
+ RACKNoDupTh
+)
+
+// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to
+// enable/disable Nagle's algorithm in TCP.
type DelayEnabled bool
-// SendBufferSizeOption allows the default, min and max send buffer sizes for
-// TCP endpoints to be queried or configured.
+// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption
+// to get/set the default, min and max TCP send buffer sizes.
type SendBufferSizeOption struct {
Min int
Default int
Max int
}
-// ReceiveBufferSizeOption allows the default, min and max receive buffer size
-// for TCP endpoints to be queried or configured.
+// ReceiveBufferSizeOption is used by
+// stack.(Stack*).TransportProtocolOption to get/set the default, min and max
+// TCP receive buffer sizes.
type ReceiveBufferSizeOption struct {
Min int
Default int
Max int
}
-const (
- ccReno = "reno"
- ccCubic = "cubic"
-)
-
// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
// value is protected by a mutex so that we can increment only when it's
// guaranteed not to go above a threshold.
@@ -158,20 +183,22 @@ func (s *synRcvdCounter) Threshold() uint64 {
type protocol struct {
mu sync.RWMutex
sackEnabled bool
+ recovery Recovery
delayEnabled bool
sendBufferSize SendBufferSizeOption
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
- tcpLingerTimeout time.Duration
- tcpTimeWaitTimeout time.Duration
+ lingerTimeout time.Duration
+ timeWaitTimeout time.Duration
+ timeWaitReuse tcpip.TCPTimeWaitReuseOption
minRTO time.Duration
maxRTO time.Duration
maxRetries uint32
synRcvdCount synRcvdCounter
synRetries uint8
- dispatcher *dispatcher
+ dispatcher dispatcher
}
// Number returns the tcp protocol number.
@@ -206,7 +233,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// to a specific processing queue. Each queue is serviced by its own processor
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
-func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
p.dispatcher.queuePacket(r, ep, id, pkt)
}
@@ -217,7 +244,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
@@ -277,6 +304,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case Recovery:
+ p.mu.Lock()
+ p.recovery = Recovery(v)
+ p.mu.Unlock()
+ return nil
+
case DelayEnabled:
p.mu.Lock()
p.delayEnabled = bool(v)
@@ -325,7 +358,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
v = 0
}
p.mu.Lock()
- p.tcpLingerTimeout = time.Duration(v)
+ p.lingerTimeout = time.Duration(v)
p.mu.Unlock()
return nil
@@ -334,7 +367,16 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
v = 0
}
p.mu.Lock()
- p.tcpTimeWaitTimeout = time.Duration(v)
+ p.timeWaitTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPTimeWaitReuseOption:
+ if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.timeWaitReuse = v
p.mu.Unlock()
return nil
@@ -391,6 +433,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.RUnlock()
return nil
+ case *Recovery:
+ p.mu.RLock()
+ *v = Recovery(p.recovery)
+ p.mu.RUnlock()
+ return nil
+
case *DelayEnabled:
p.mu.RLock()
*v = DelayEnabled(p.delayEnabled)
@@ -429,13 +477,19 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
case *tcpip.TCPLingerTimeoutOption:
p.mu.RLock()
- *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
+ *v = tcpip.TCPLingerTimeoutOption(p.lingerTimeout)
p.mu.RUnlock()
return nil
case *tcpip.TCPTimeWaitTimeoutOption:
p.mu.RLock()
- *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.timeWaitTimeout)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitReuseOption:
+ p.mu.RLock()
+ *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse)
p.mu.RUnlock()
return nil
@@ -490,20 +544,51 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter {
return &p.synRcvdCount
}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ // TCP header is variable length, peek at it first.
+ hdrLen := header.TCPMinimumSize
+ hdr, ok := pkt.Data.PullUp(hdrLen)
+ if !ok {
+ return false
+ }
+
+ // If the header has options, pull those up as well.
+ if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
+ // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of
+ // packets.
+ hdrLen = offset
+ }
+
+ _, ok = pkt.TransportHeader().Consume(hdrLen)
+ return ok
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ p := protocol{
+ sendBufferSize: SendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultSendBufferSize,
+ Max: MaxBufferSize,
+ },
+ recvBufferSize: ReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultReceiveBufferSize,
+ Max: MaxBufferSize,
+ },
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
- tcpLingerTimeout: DefaultTCPLingerTimeout,
- tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ lingerTimeout: DefaultTCPLingerTimeout,
+ timeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly,
synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
- dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
synRetries: DefaultSynRetries,
minRTO: MinRTO,
maxRTO: MaxRTO,
maxRetries: MaxRetries,
+ recovery: RACKLossDetection,
}
+ p.dispatcher.init(runtime.GOMAXPROCS(0))
+ return &p
}
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
new file mode 100644
index 000000000..d969ca23a
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -0,0 +1,82 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// RACK is a loss detection algorithm used in TCP to detect packet loss and
+// reordering using transmission timestamp of the packets instead of packet or
+// sequence counts. To use RACK, SACK should be enabled on the connection.
+
+// rackControl stores the rack related fields.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-6.1
+//
+// +stateify savable
+type rackControl struct {
+ // xmitTime is the latest transmission timestamp of rackControl.seg.
+ xmitTime time.Time `state:".(unixTime)"`
+
+ // endSequence is the ending TCP sequence number of rackControl.seg.
+ endSequence seqnum.Value
+
+ // fack is the highest selectively or cumulatively acknowledged
+ // sequence.
+ fack seqnum.Value
+
+ // rtt is the RTT of the most recently delivered packet on the
+ // connection (either cumulatively acknowledged or selectively
+ // acknowledged) that was not marked invalid as a possible spurious
+ // retransmission.
+ rtt time.Duration
+}
+
+// Update will update the RACK related fields when an ACK has been received.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) {
+ rtt := time.Now().Sub(seg.xmitTime)
+
+ // If the ACK is for a retransmitted packet, do not update if it is a
+ // spurious inference which is determined by below checks:
+ // 1. When Timestamping option is available, if the TSVal is less than the
+ // transmit time of the most recent retransmitted packet.
+ // 2. When RTT calculated for the packet is less than the smoothed RTT
+ // for the connection.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+ // step 2
+ if seg.xmitCount > 1 {
+ if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
+ if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) {
+ return
+ }
+ }
+ if rtt < srtt {
+ return
+ }
+ }
+
+ rc.rtt = rtt
+ // Update rc.xmitTime and rc.endSequence to the transmit time and
+ // ending sequence number of the packet which has been acknowledged
+ // most recently.
+ endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) {
+ rc.xmitTime = seg.xmitTime
+ rc.endSequence = endSeq
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go
new file mode 100644
index 000000000..c9dc7e773
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rack_state.go
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveXmitTime is invoked by stateify.
+func (rc *rackControl) saveXmitTime() unixTime {
+ return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (rc *rackControl) loadXmitTime(unix unixTime) {
+ rc.xmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index dd89a292a..5e0bfe585 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -372,7 +372,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
// We only store the segment if it's within our buffer
// size limit.
if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += s.logicalLen()
+ r.pendingBufUsed += seqnum.Size(s.segMemSize())
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -406,7 +406,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= s.logicalLen()
+ r.pendingBufUsed -= seqnum.Size(s.segMemSize())
s.decRef()
}
return false, nil
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 074edded6..94307d31a 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -35,6 +35,7 @@ type segment struct {
id stack.TransportEndpointID `state:"manual"`
route stack.Route `state:"manual"`
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View `state:"nosave"`
@@ -60,13 +61,14 @@ type segment struct {
xmitCount uint32
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
s.data = pkt.Data.Clone(s.views[:])
+ s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
return s
}
@@ -136,6 +138,12 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// segMemSize is the amount of memory used to hold the segment data and
+// the associated metadata.
+func (s *segment) segMemSize() int {
+ return segSize + s.data.Size()
+}
+
// parse populates the sequence & ack numbers, flags, and window fields of the
// segment from the TCP header stored in the data. It then updates the view to
// skip the header.
@@ -146,12 +154,6 @@ func (s *segment) logicalLen() seqnum.Size {
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
func (s *segment) parse() bool {
- h, ok := s.data.PullUp(header.TCPMinimumSize)
- if !ok {
- return false
- }
- hdr := header.TCP(h)
-
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
// 1. That it's at least the minimum header size; if we don't do this
@@ -162,16 +164,12 @@ func (s *segment) parse() bool {
// N.B. The segment has already been validated as having at least the
// minimum TCP size before reaching here, so it's safe to read the
// fields.
- offset := int(hdr.DataOffset())
- if offset < header.TCPMinimumSize {
- return false
- }
- hdrWithOpts, ok := s.data.PullUp(offset)
- if !ok {
+ offset := int(s.hdr.DataOffset())
+ if offset < header.TCPMinimumSize || offset > len(s.hdr) {
return false
}
- s.options = []byte(hdrWithOpts[header.TCPMinimumSize:])
+ s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
// Query the link capabilities to decide if checksum validation is
@@ -180,22 +178,19 @@ func (s *segment) parse() bool {
if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
s.csumValid = true
verifyChecksum = false
- s.data.TrimFront(offset)
}
if verifyChecksum {
- hdr = header.TCP(hdrWithOpts)
- s.csum = hdr.Checksum()
- xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
- xsum = hdr.CalculateChecksum(xsum)
- s.data.TrimFront(offset)
+ s.csum = s.hdr.Checksum()
+ xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr)))
+ xsum = s.hdr.CalculateChecksum(xsum)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
}
- s.sequenceNumber = seqnum.Value(hdr.SequenceNumber())
- s.ackNumber = seqnum.Value(hdr.AckNumber())
- s.flags = hdr.Flags()
- s.window = seqnum.Size(hdr.WindowSize())
+ s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber())
+ s.ackNumber = seqnum.Value(s.hdr.AckNumber())
+ s.flags = s.hdr.Flags()
+ s.window = seqnum.Size(s.hdr.WindowSize())
return true
}
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
new file mode 100644
index 000000000..0ab7b8f56
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "unsafe"
+)
+
+const (
+ segSize = int(unsafe.Sizeof(segment{}))
+)
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 06dc9b7d7..c55589c45 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -191,6 +191,10 @@ type sender struct {
// cc is the congestion control algorithm in use for this sender.
cc congestionControl
+
+ // rc has the fields needed for implementing RACK loss detection
+ // algorithm.
+ rc rackControl
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -618,6 +622,20 @@ func (s *sender) splitSeg(seg *segment, size int) {
nSeg.data.TrimFront(size)
nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
s.writeList.InsertAfter(seg, nSeg)
+
+ // The segment being split does not carry PUSH flag because it is
+ // followed by the newly split segment.
+ // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered
+ // segment (i.e., when there is no more queued data to be sent).
+ // Linux removes PSH flag only when the segment is being split over MSS
+ // and retains it when we are splitting the segment over lack of sender
+ // window space.
+ // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test()
+ if seg.data.Size() > s.maxPayloadSize {
+ seg.flags ^= header.TCPFlagPsh
+ }
+
seg.data.CapLength(size)
}
@@ -739,7 +757,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if !s.isAssignedSequenceNumber(seg) {
// Merge segments if allowed.
if seg.data.Size() != 0 {
- available := int(seg.sequenceNumber.Size(end))
+ available := int(s.sndNxt.Size(end))
if available > limit {
available = limit
}
@@ -782,8 +800,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// sent all at once.
return false
}
- if atomic.LoadUint32(&s.ep.cork) != 0 {
- // Hold back the segment until full.
+ // With TCP_CORK, hold back until minimum of the available
+ // send space and MSS.
+ // TODO(gvisor.dev/issue/2833): Drain the held segments after a
+ // timeout.
+ if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
return false
}
}
@@ -824,10 +845,52 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if available == 0 {
return false
}
+
+ // If the whole segment or at least 1MSS sized segment cannot
+ // be accomodated in the receiver advertized window, skip
+ // splitting and sending of the segment. ref:
+ // net/ipv4/tcp_output.c::tcp_snd_wnd_test()
+ //
+ // Linux checks this for all segment transmits not triggered by
+ // a probe timer. On this condition, it defers the segment split
+ // and transmit to a short probe timer.
+ //
+ // ref: include/net/tcp.h::tcp_check_probe_timer()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup()
+ //
+ // Instead of defining a new transmit timer, we attempt to split
+ // the segment right here if there are no pending segments. If
+ // there are pending segments, segment transmits are deferred to
+ // the retransmit timer handler.
+ if s.sndUna != s.sndNxt {
+ switch {
+ case available >= seg.data.Size():
+ // OK to send, the whole segments fits in the
+ // receiver's advertised window.
+ case available >= s.maxPayloadSize:
+ // OK to send, at least 1 MSS sized segment fits
+ // in the receiver's advertised window.
+ default:
+ return false
+ }
+ }
+
+ // The segment size limit is computed as a function of sender
+ // congestion window and MSS. When sender congestion window is >
+ // 1, this limit can be larger than MSS. Ensure that the
+ // currently available send space is not greater than minimum of
+ // this limit and MSS.
if available > limit {
available = limit
}
+ // If GSO is not in use then cap available to
+ // maxPayloadSize. When GSO is in use the gVisor GSO logic or
+ // the host GSO logic will cap the segment to the correct size.
+ if s.ep.gso == nil && available > s.maxPayloadSize {
+ available = s.maxPayloadSize
+ }
+
if seg.data.Size() > available {
s.splitSeg(seg, available)
}
@@ -1213,21 +1276,21 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
-func (s *sender) handleRcvdSegment(seg *segment) {
+func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Check if we can extract an RTT measurement from this ack.
- if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
s.updateRTO(time.Now().Sub(s.rttMeasureTime))
s.rttMeasureSeqNum = s.sndNxt
}
// Update Timestamp if required. See RFC7323, section-4.3.
- if s.ep.sendTSOk && seg.parsedOptions.TS {
- s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+ if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber)
}
// Insert SACKBlock information into our scoreboard.
if s.ep.sackPermitted {
- for _, sb := range seg.parsedOptions.SACKBlocks {
+ for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
// true:
// * SACK block acks data after the ack number in the
@@ -1240,27 +1303,27 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// NOTE: This check specifically excludes DSACK blocks
// which have start/end before sndUna and are used to
// indicate spurious retransmissions.
- if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
s.ep.scoreboard.Insert(sb)
- seg.hasNewSACKInfo = true
+ rcvdSeg.hasNewSACKInfo = true
}
}
s.SetPipe()
}
// Count the duplicates and do the fast retransmit if needed.
- rtx := s.checkDuplicateAck(seg)
+ rtx := s.checkDuplicateAck(rcvdSeg)
// Stash away the current window size.
- s.sndWnd = seg.window
+ s.sndWnd = rcvdSeg.window
- ack := seg.ackNumber
+ ack := rcvdSeg.ackNumber
// Disable zero window probing if remote advertizes a non-zero receive
// window. This can be with an ACK to the zero window probe (where the
// acknumber refers to the already acknowledged byte) OR to any previously
// unacknowledged segment.
- if s.zeroWindowProbing && seg.window > 0 &&
+ if s.zeroWindowProbing && rcvdSeg.window > 0 &&
(ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) {
s.disableZeroWindowProbing()
}
@@ -1285,10 +1348,10 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// averaged RTT measurement only if the segment acknowledges
// some new data, i.e., only if it advances the left edge of
// the send window.
- if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
// TSVal/Ecr values sent by Netstack are at a millisecond
// granularity.
- elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
s.updateRTO(elapsed)
}
@@ -1302,6 +1365,9 @@ func (s *sender) handleRcvdSegment(seg *segment) {
ackLeft := acked
originalOutstanding := s.outstanding
+ s.rtt.Lock()
+ srtt := s.rtt.srtt
+ s.rtt.Unlock()
for ackLeft > 0 {
// We use logicalLen here because we can have FIN
// segments (which are always at the end of list) that
@@ -1321,6 +1387,11 @@ func (s *sender) handleRcvdSegment(seg *segment) {
s.writeNext = seg.Next()
}
+ // Update the RACK fields if SACK is enabled.
+ if s.ep.sackPermitted {
+ s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset)
+ }
+
s.writeList.Remove(seg)
// if SACK is enabled then Only reduce outstanding if
@@ -1376,7 +1447,7 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// that the window opened up, or the congestion window was inflated due
// to a duplicate ack during fast recovery. This will also re-enable
// the retransmit timer if needed.
- if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+ if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo {
s.sendData()
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 5fe23113b..b9993ce1a 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -50,7 +50,7 @@ func TestFastRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -90,14 +90,14 @@ func TestFastRecovery(t *testing.T) {
// Wait before checking metrics.
metricPollFn := func() error {
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want)
}
return nil
}
@@ -128,10 +128,10 @@ func TestFastRecovery(t *testing.T) {
// Wait before checking metrics.
metricPollFn = func() error {
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
}
return nil
}
@@ -215,7 +215,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
expected := tcp.InitialCwnd
@@ -257,7 +257,7 @@ func TestCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -362,7 +362,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -471,11 +471,11 @@ func TestRetransmit(t *testing.T) {
// MTU size though.
half := data[:len(data)/2]
if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
half = data[len(data)/2:]
if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -508,23 +508,23 @@ func TestRetransmit(t *testing.T) {
metricPollFn := func() error {
if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want)
}
return nil
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
new file mode 100644
index 000000000..e03f101e8
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+)
+
+// TestRACKUpdate tests the RACK related fields are updated when an ACK is
+// received on a SACK enabled connection.
+func TestRACKUpdate(t *testing.T) {
+ const maxPayload = 10
+ const tsOptionSize = 12
+ const maxTCPOptionSize = 40
+
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ defer c.Cleanup()
+
+ var xmitTime time.Time
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.RACKState is what we expect.
+ if state.Sender.RACKState.XmitTime.Before(xmitTime) {
+ t.Fatalf("RACK transmit time failed to update when an ACK is received")
+ }
+
+ gotSeq := state.Sender.RACKState.EndSequence
+ wantSeq := state.Sender.SndNxt
+ if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
+ t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq)
+ }
+
+ if state.Sender.RACKState.RTT == 0 {
+ t.Fatalf("RACK RTT failed to update when an ACK is received")
+ }
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ data := buffer.NewView(maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data.
+ xmitTime = time.Now()
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ c.SendAck(790, bytesRead)
+ time.Sleep(200 * time.Millisecond)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index ace79b7b2..99521f0c1 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -47,7 +47,7 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
t.Helper()
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err)
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err)
}
}
@@ -400,7 +400,7 @@ func TestSACKRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -454,7 +454,7 @@ func TestSACKRecovery(t *testing.T) {
}
for _, s := range stats {
if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %v, want = %v", s.name, got, want)
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
}
}
return nil
@@ -529,19 +529,19 @@ func TestSACKRecovery(t *testing.T) {
// In SACK recovery only the first segment is fast retransmitted when
// entering recovery.
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want)
}
return nil
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6ef32a1b3..0f7e958e4 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -57,7 +57,7 @@ func TestGiveUpConnect(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Register for notification, then start connection attempt.
@@ -66,7 +66,7 @@ func TestGiveUpConnect(t *testing.T) {
defer wq.EventUnregister(&waitEntry)
if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
// Close the connection, wait for completion.
@@ -75,21 +75,21 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
+ t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted)
}
// Call Connect again to retreive the handshake failure status
// and stats updates.
if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted)
}
if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
}
@@ -102,7 +102,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
}
@@ -115,10 +115,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want)
+ t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want)
}
}
@@ -129,20 +129,38 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
want := stats.TCP.FailedConnectionAttempts.Value() + 1
if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
- t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute)
}
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want)
+ t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want)
+ }
+}
+
+func TestCloseWithoutConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ c.EP.Close()
+
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -156,10 +174,10 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
- t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
- t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want)
+ t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want)
}
}
@@ -170,16 +188,16 @@ func TestTCPResetsSentIncrement(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
want := stats.TCP.SegmentsSent.Value() + 1
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send a SYN request.
@@ -213,7 +231,7 @@ func TestTCPResetsSentIncrement(t *testing.T) {
metricPollFn := func() error {
if got := stats.TCP.ResetsSent.Value(); got != want {
- return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want)
}
return nil
}
@@ -292,7 +310,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// are released instantly on Close.
tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
- t.Fatalf("e.stack.SetTransportProtocolOption(%d, %v) = %v", tcp.ProtocolNumber, tcpTW, err)
+ t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err)
}
c.EP.Close()
@@ -355,7 +373,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
})
if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
}
}
@@ -379,7 +397,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) {
})
if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
}
c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
}
@@ -403,7 +421,7 @@ func TestNonBlockingClose(t *testing.T) {
t0 := time.Now()
ep.Close()
if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
+ t.Fatalf("Took too long to close: %s", diff)
}
}
@@ -415,7 +433,7 @@ func TestConnectResetAfterClose(t *testing.T) {
// after 3 second in FIN_WAIT2 state.
tcpLingerTimeout := 3 * time.Second
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err)
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -497,11 +515,11 @@ func TestCurrentConnectedIncrement(t *testing.T) {
c.EP = nil
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got)
}
gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
if gotConnected != 1 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 1", gotConnected)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected)
}
ep.Close()
@@ -524,10 +542,10 @@ func TestCurrentConnectedIncrement(t *testing.T) {
})
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = %v", got, gotConnected)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected)
}
// Ack and send FIN as well.
@@ -556,10 +574,10 @@ func TestCurrentConnectedIncrement(t *testing.T) {
time.Sleep(1200 * time.Millisecond)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -575,7 +593,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
c.EP = nil
if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
}
// Send a FIN for ESTABLISHED --> CLOSED-WAIT
@@ -603,7 +621,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
time.Sleep(10 * time.Millisecond)
if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
}
// Close the application endpoint for CLOSE_WAIT --> LAST_ACK
@@ -620,7 +638,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
)
if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Pause the endpoint`s protocolMainLoop.
@@ -657,15 +675,15 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
// Expect the endpoint to be closed.
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
// Check if the endpoint was moved to CLOSED and netstack a reset in
@@ -691,7 +709,7 @@ func TestSimpleReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -714,7 +732,7 @@ func TestSimpleReceive(t *testing.T) {
// Receive data.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -781,7 +799,7 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
// Start connection attempt to IPv4 address.
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet with our user supplied MSS.
@@ -842,7 +860,7 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
// Start connection attempt to IPv6 address.
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet with our user supplied MSS.
@@ -1239,7 +1257,7 @@ func TestConnectBindToDevice(t *testing.T) {
defer c.WQ.EventUnregister(&waitEntry)
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -1251,7 +1269,7 @@ func TestConnectBindToDevice(t *testing.T) {
),
)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
}
tcpHdr := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
@@ -1270,74 +1288,97 @@ func TestConnectBindToDevice(t *testing.T) {
c.GetPacket()
if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
}
})
}
}
-func TestRstOnSynSent(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
+func TestSynSent(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ reset bool
+ }{
+ {"RstOnSynSent", true},
+ {"CloseOnSynSent", false},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
- // Create an endpoint, don't handshake because we want to interfere with the
- // handshake process.
- c.Create(-1)
+ // Create an endpoint, don't handshake because we want to interfere with the
+ // handshake process.
+ c.Create(-1)
- // Start connection attempt.
- waitEntry, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventOut)
- defer c.WQ.EventUnregister(&waitEntry)
+ // Start connection attempt.
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
- addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted)
- }
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
+ }
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
- // Ensure that we've reached SynSent state
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
- // Send a packet with a proper ACK and a RST flag to cause the socket
- // to Error and close out
- iss := seqnum.Value(789)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
+ if test.reset {
+ // Send a packet with a proper ACK and a RST flag to cause the socket
+ // to error and close out.
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+ } else {
+ c.EP.Close()
+ }
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatal("timed out waiting for packet to arrive")
- }
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for packet to arrive")
+ }
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused)
- }
+ if test.reset {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
+ }
+ } else {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted)
+ }
+ }
- // Due to the RST the endpoint should be in an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+
+ // Due to the RST the endpoint should be in an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ })
}
}
@@ -1352,7 +1393,7 @@ func TestOutOfOrderReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send second half of data first, with seqnum 3 ahead of expected.
@@ -1379,7 +1420,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send the first 3 bytes now.
@@ -1406,7 +1447,7 @@ func TestOutOfOrderReceive(t *testing.T) {
}
continue
}
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
read = append(read, v...)
@@ -1436,7 +1477,7 @@ func TestOutOfOrderFlood(t *testing.T) {
c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send 100 packets before the actual one that is expected.
@@ -1513,7 +1554,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -1556,7 +1597,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// This final ACK should be ignored because an ACK on a reset doesn't mean
@@ -1582,7 +1623,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -1624,7 +1665,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
))
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Cause a RST to be generated by closing the read end now since we have
@@ -1643,7 +1684,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// The ACK to the FIN should now be rejected since the connection has been
@@ -1665,19 +1706,19 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
- t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want)
+ t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
}
}
@@ -1693,7 +1734,7 @@ func TestFullWindowReceive(t *testing.T) {
_, _, err := c.EP.Read(nil)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
// Fill up the window.
@@ -1728,7 +1769,7 @@ func TestFullWindowReceive(t *testing.T) {
// Receive data and check it.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -1737,7 +1778,7 @@ func TestFullWindowReceive(t *testing.T) {
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
- t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want)
+ t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want)
}
// Check that we get an ACK for the newly non-zero window.
@@ -1760,7 +1801,7 @@ func TestNoWindowShrinking(t *testing.T) {
c.CreateConnected(789, 30000, 10)
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -1768,7 +1809,7 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send 3 bytes, check that the peer acknowledges them.
@@ -1832,7 +1873,7 @@ func TestNoWindowShrinking(t *testing.T) {
for len(read) < len(data) {
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
read = append(read, v...)
@@ -1866,7 +1907,7 @@ func TestSimpleSend(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received.
@@ -1908,7 +1949,7 @@ func TestZeroWindowSend(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check if we got a zero-window probe.
@@ -1976,7 +2017,7 @@ func TestScaledWindowConnect(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xbfff,
@@ -2008,7 +2049,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -2036,21 +2077,21 @@ func TestScaledWindowAccept(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -2068,7 +2109,7 @@ func TestScaledWindowAccept(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2081,7 +2122,7 @@ func TestScaledWindowAccept(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xbfff,
@@ -2109,21 +2150,21 @@ func TestNonScaledWindowAccept(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
@@ -2142,7 +2183,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2155,7 +2196,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -2244,7 +2285,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
for sz < defaultMTU {
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
sz += len(v)
}
@@ -2311,7 +2352,7 @@ func TestSegmentMerging(t *testing.T) {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2381,7 +2422,7 @@ func TestDelay(t *testing.T) {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2428,7 +2469,7 @@ func TestUndelay(t *testing.T) {
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2512,7 +2553,7 @@ func TestMSSNotDelayed(t *testing.T) {
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2563,7 +2604,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received in chunks.
@@ -2631,7 +2672,7 @@ func TestSetTTL(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
@@ -2639,7 +2680,7 @@ func TestSetTTL(t *testing.T) {
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %s", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -2671,7 +2712,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
@@ -2683,11 +2724,11 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -2705,7 +2746,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2794,7 +2835,7 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) {
select {
case err := <-ch:
if err != nil {
- t.Fatalf("Error creating endpoint: %v", err)
+ t.Fatalf("Error creating endpoint: %s", err)
}
case <-time.After(2 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -2813,7 +2854,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Set the buffer size to a deterministic size so that we can check the
@@ -2830,7 +2871,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
// Receive SYN packet.
@@ -2884,7 +2925,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
select {
case <-ch:
if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -2899,22 +2940,22 @@ func TestCloseListener(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Close the listener and measure how long it takes.
t0 := time.Now()
ep.Close()
if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
+ t.Fatalf("Took too long to close: %s", diff)
}
}
@@ -2950,22 +2991,25 @@ loop:
case tcpip.ErrConnectionReset:
break loop
default:
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
// Expect the state to be StateError and subsequent Reads to fail with HardError.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -2990,7 +3034,7 @@ func TestSendOnResetConnection(t *testing.T) {
// Try to write.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
@@ -3013,7 +3057,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Expect first transmit and MaxRetries retransmits.
@@ -3048,7 +3092,10 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
)
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -3066,7 +3113,7 @@ func TestMaxRTO(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.TCP(
@@ -3089,6 +3136,63 @@ func TestMaxRTO(t *testing.T) {
}
}
+// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
+// unique on retransmits.
+func TestRetransmitIPv4IDUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ size int
+ }{
+ {"1Byte", 1},
+ {"512Bytes", 512},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ // Disabling PMTU discovery causes all packets sent from this socket to
+ // have DF=0. This needs to be done because the IPv4 ID uniqueness
+ // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
+ // Section 4, and datagrams with DF=0 are non-atomic.
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
+ t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}}
+ // Expect two retransmitted packets, and that all packets received have
+ // unique IPv4 ID values.
+ for i := 0; i <= 2; i++ {
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ id := header.IPv4(pkt).ID()
+ if _, exists := idSet[id]; exists {
+ t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
+ }
+ idSet[id] = struct{}{}
+ }
+ })
+ }
+}
+
func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -3097,7 +3201,7 @@ func TestFinImmediately(t *testing.T) {
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3140,7 +3244,7 @@ func TestFinRetransmit(t *testing.T) {
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3195,7 +3299,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3221,7 +3325,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Shutdown, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3268,7 +3372,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
view := buffer.NewView(10)
for i := tcp.InitialCwnd; i > 0; i-- {
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
}
@@ -3290,7 +3394,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
// because the congestion window doesn't allow it. Wait until a
// retransmit is received.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3354,7 +3458,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3380,7 +3484,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write new data, but don't acknowledge it.
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3396,7 +3500,7 @@ func TestFinWithPendingData(t *testing.T) {
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3441,7 +3545,7 @@ func TestFinWithPartialAck(t *testing.T) {
// FIN from the test side.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3478,7 +3582,7 @@ func TestFinWithPartialAck(t *testing.T) {
// Write new data, but don't acknowledge it.
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3494,7 +3598,7 @@ func TestFinWithPartialAck(t *testing.T) {
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3540,20 +3644,20 @@ func TestUpdateListenBacklog(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Update the backlog with another Listen() on the same endpoint.
if err := ep.Listen(20); err != nil {
- t.Fatalf("Listen failed to update backlog: %v", err)
+ t.Fatalf("Listen failed to update backlog: %s", err)
}
ep.Close()
@@ -3585,7 +3689,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
// Send some data. Check that it's capped by the window size.
view := buffer.NewView(65535)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that only data that fits in the scaled window is sent.
@@ -3631,18 +3735,18 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
})
if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
- t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want)
+ t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want)
}
// Ensure there were no errors during handshake. If these stats have
// incremented, then the connection should not have been established.
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0)
+ t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0)
}
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0)
+ t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0)
}
}
@@ -3666,10 +3770,10 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c.SendSegment(vv)
if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
}
}
@@ -3770,7 +3874,7 @@ func TestReadAfterClosedState(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Shutdown immediately for write, check that we get a FIN.
@@ -3789,7 +3893,7 @@ func TestReadAfterClosedState(t *testing.T) {
)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Send some data and acknowledge the FIN.
@@ -3818,7 +3922,7 @@ func TestReadAfterClosedState(t *testing.T) {
time.Sleep(tcpTimeWaitTimeout * 2)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Wait for receive to be notified.
@@ -3853,11 +3957,11 @@ func TestReadAfterClosedState(t *testing.T) {
// Now that we drained the queue, check that functions fail with the
// right error code.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -3871,66 +3975,84 @@ func TestReusePort(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
// Second case, an endpoint that was bound and is connecting..
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
// Third case, an endpoint that was bound and is listening.
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
}
@@ -3939,11 +4061,11 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
if int(s) != v {
- t.Fatalf("got receive buffer size = %v, want = %v", s, v)
+ t.Fatalf("got receive buffer size = %d, want = %d", s, v)
}
}
@@ -3952,11 +4074,11 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
if int(s) != v {
- t.Fatalf("got send buffer size = %v, want = %v", s, v)
+ t.Fatalf("got send buffer size = %d, want = %d", s, v)
}
}
@@ -3969,7 +4091,7 @@ func TestDefaultBufferSizes(t *testing.T) {
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer func() {
if ep != nil {
@@ -3981,28 +4103,34 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultSendBufferSize * 2,
+ Max: tcp.DefaultSendBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
ep.Close()
ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize * 3,
+ Max: tcp.DefaultReceiveBufferSize * 30}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
ep.Close()
ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
@@ -4018,17 +4146,17 @@ func TestMinMaxBufferSizes(t *testing.T) {
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
// Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Set values below the min.
@@ -4065,12 +4193,12 @@ func TestBindToDeviceOption(t *testing.T) {
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
if err := s.CreateNIC(321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
+ t.Errorf("CreateNIC failed: %s", err)
}
// nicIDPtr is used instead of taking the address of NICID literals, which is
@@ -4095,12 +4223,12 @@ func TestBindToDeviceOption(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
+ t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
bindToDevice := tcpip.BindToDeviceOption(88888)
if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt got %v, want %v", err, nil)
+ t.Errorf("GetSockOpt got %s, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
t.Errorf("bindToDevice got %d, want %d", got, want)
@@ -4166,12 +4294,12 @@ func TestSelfConnect(t *testing.T) {
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Register for notification, then start connection attempt.
@@ -4180,12 +4308,12 @@ func TestSelfConnect(t *testing.T) {
defer wq.EventUnregister(&waitEntry)
if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
<-notifyCh
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("Connect failed: %v", err)
+ t.Fatalf("Connect failed: %s", err)
}
// Write something.
@@ -4193,7 +4321,7 @@ func TestSelfConnect(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Read back what was written.
@@ -4202,12 +4330,12 @@ func TestSelfConnect(t *testing.T) {
rd, _, err := ep.Read(nil)
if err != nil {
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
<-notifyCh
rd, _, err = ep.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
}
@@ -4291,7 +4419,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
eps = append(eps, ep)
switch network {
@@ -4342,7 +4470,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
- t.Fatalf("Bind(%d) failed: %v", i, err)
+ t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
want := tcpip.ErrConnectStarted
@@ -4350,7 +4478,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
want = tcpip.ErrNoPortAvailable
}
if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
- t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want)
+ t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
}
})
}
@@ -4384,7 +4512,7 @@ func TestPathMTUDiscovery(t *testing.T) {
}
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
@@ -4487,7 +4615,7 @@ func TestStackSetCongestionControl(t *testing.T) {
var oldCC tcpip.CongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
@@ -4574,12 +4702,12 @@ func TestEndpointSetCongestionControl(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
var oldCC tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&oldCC); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
+ t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err)
}
if connected {
@@ -4587,12 +4715,12 @@ func TestEndpointSetCongestionControl(t *testing.T) {
}
if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
- t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
+ t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err)
}
var cc tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&cc); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
+ t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err)
}
got, want := cc, oldCC
@@ -4615,7 +4743,7 @@ func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
opt := tcpip.CongestionControlOption("cubic")
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err)
}
}
@@ -4657,14 +4785,14 @@ func TestKeepalive(t *testing.T) {
// Check that the connection is still alive.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
view := buffer.NewView(3)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -4744,15 +4872,18 @@ func TestKeepalive(t *testing.T) {
)
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -4854,19 +4985,19 @@ func TestListenBacklogFull(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
// Start listening.
listenBacklog := 2
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
for i := 0; i < listenBacklog; i++ {
@@ -4899,7 +5030,7 @@ func TestListenBacklogFull(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4928,7 +5059,7 @@ func TestListenBacklogFull(t *testing.T) {
case <-ch:
newEP, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4942,7 +5073,7 @@ func TestListenBacklogFull(t *testing.T) {
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
}
@@ -5162,19 +5293,19 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
// Start listening.
listenBacklog := 1
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send two SYN's the first one should get a SYN-ACK, the
@@ -5240,7 +5371,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
case <-ch:
newEP, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5254,7 +5385,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
pkt := c.GetPacket()
tcp = header.TCP(header.IPv4(pkt).Payload())
if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
}
@@ -5316,7 +5447,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5450,7 +5581,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
pkt := c.GetPacket()
tcpHdr = header.TCP(header.IPv4(pkt).Payload())
if string(tcpHdr.Payload()) != data {
- t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
}
}
@@ -5460,20 +5591,20 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
stats := c.Stack().Stats()
@@ -5494,7 +5625,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5503,7 +5634,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
}
if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
}
@@ -5514,14 +5645,14 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
srcPort := uint16(context.TestPort)
@@ -5546,10 +5677,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
time.Sleep(50 * time.Millisecond)
if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want)
+ t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -5564,7 +5695,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5579,28 +5710,28 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
- t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected)
+ t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
}
if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
- t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1)
+ t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -5617,7 +5748,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
case <-ch:
aep, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5625,25 +5756,25 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
}
}
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
- t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected)
+ t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected)
}
// Listening endpoint remains in listen state.
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
ep.Close()
// Give worker goroutines time to receive the close notification.
time.Sleep(1 * time.Second)
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Accepted endpoint remains open when the listen endpoint is closed.
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
}
@@ -5663,13 +5794,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 500.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Enable auto-tuning.
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Change the expected window scale to match the value needed for the
// maximum buffer size defined above.
@@ -5784,13 +5915,13 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Enable auto-tuning.
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Change the expected window scale to match the value needed for the
// maximum buffer size used by stack.
@@ -5935,7 +6066,7 @@ func TestDelayEnabled(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
- t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err)
+ t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err)
}
checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
}
@@ -5946,7 +6077,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del
var gotDelayEnabled tcp.DelayEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
- t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err)
+ t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
}
if gotDelayEnabled != wantDelayEnabled {
t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
@@ -5954,7 +6085,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
if err != nil {
- t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err)
+ t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
}
gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
@@ -5981,7 +6112,7 @@ func TestTCPLingerTimeout(t *testing.T) {
{"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
// Values > stack's TCPLingerTimeout are capped to the stack's
// value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
- {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second},
+ {"AboveMaxLingerTimeout", 125 * time.Second, 120 * time.Second},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@@ -6515,10 +6646,10 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.TCPFlags(header.TCPFlagRst)))
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
}
@@ -6715,7 +6846,7 @@ func TestTCPUserTimeout(t *testing.T) {
// Send some data and wait before ACKing it.
view := buffer.NewView(3)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -6765,11 +6896,14 @@ func TestTCPUserTimeout(t *testing.T) {
)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -6796,7 +6930,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
// Check that the connection is still alive.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Now receive 1 keepalives, but don't ACK it.
@@ -6837,10 +6971,13 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -6896,11 +7033,11 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
// ack should be sent in response to that. The window was not
// zero, but it grew to larger than MSS.
if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
// After reading two packets, we surely crossed MSS. See the ack:
@@ -6997,13 +7134,13 @@ func TestTCPDeferAccept(t *testing.T) {
const tcpDeferAccept = 1 * time.Second
if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
}
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Send data. This should result in an acceptable endpoint.
@@ -7026,7 +7163,7 @@ func TestTCPDeferAccept(t *testing.T) {
time.Sleep(50 * time.Millisecond)
aep, _, err := c.EP.Accept()
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
}
aep.Close()
@@ -7054,13 +7191,13 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
const tcpDeferAccept = 1 * time.Second
if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
}
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Sleep for a little of the tcpDeferAccept timeout.
@@ -7094,7 +7231,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
time.Sleep(50 * time.Millisecond)
aep, _, err := c.EP.Accept()
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
}
aep.Close()
@@ -7160,3 +7297,53 @@ func TestResetDuringClose(t *testing.T) {
wg.Wait()
}
+
+func TestStackTimeWaitReuse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err)
+ }
+ if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want {
+ t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
+ }
+}
+
+func TestSetStackTimeWaitReuse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ testCases := []struct {
+ v int
+ err *tcpip.Error
+ }{
+ {int(tcpip.TCPTimeWaitReuseDisabled), nil},
+ {int(tcpip.TCPTimeWaitReuseGlobal), nil},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue},
+ {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue},
+ }
+
+ for _, tc := range testCases {
+ err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitReuseOption(tc.v))
+ if got, want := err, tc.err; got != want {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.v, err, tc.err)
+ }
+ if tc.err != nil {
+ continue
+ }
+
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err)
+ }
+
+ if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want {
+ t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 7b1d72cf4..927bc71e0 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -143,13 +143,15 @@ func New(t *testing.T, mtu uint32) *Context {
TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
})
+ const sendBufferSize = 1 << 20 // 1 MiB
+ const recvBufferSize = 1 << 20 // 1 MiB
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Increase minimum RTO in tests to avoid test flakes due to early
@@ -202,7 +204,7 @@ func New(t *testing.T, mtu uint32) *Context {
t: t,
s: s,
linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
+ WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
}
}
@@ -255,8 +257,8 @@ func (c *Context) GetPacket() []byte {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
@@ -282,8 +284,8 @@ func (c *Context) GetPacketNonBlocking() []byte {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
return b
@@ -316,9 +318,10 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// BuildSegment builds a TCP segment based on the given Headers and payload.
@@ -372,26 +375,29 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: s,
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: c.BuildSegment(payload, h),
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
// provided source and destination IPv4 addresses.
func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendAck sends an ACK packet.
@@ -512,9 +518,8 @@ func (c *Context) GetV6Packet() []byte {
if p.Proto != ipv6.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
- b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
- copy(b, p.Pkt.Header.View())
- copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
return b
@@ -564,9 +569,10 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
})
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt)
}
// CreateConnected creates a connected TCP endpoint.
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index c70525f27..7981d469b 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -85,6 +85,7 @@ func (t *timer) init(w *sleep.Waker) {
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
t.timer.Stop()
+ *t = timer{}
}
// checkExpiration checks if the given timer has actually expired, it should be
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
new file mode 100644
index 000000000..dbd6dff54
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer_test.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+)
+
+func TestCleanup(t *testing.T) {
+ const (
+ timerDurationSeconds = 2
+ isAssertedTimeoutSeconds = timerDurationSeconds + 1
+ )
+
+ tmr := timer{}
+ w := sleep.Waker{}
+ tmr.init(&w)
+ tmr.enable(timerDurationSeconds * time.Second)
+ tmr.cleanup()
+
+ if want := (timer{}); tmr != want {
+ t.Errorf("got tmr = %+v, want = %+v", tmr, want)
+ }
+
+ // The waker should not be asserted.
+ for i := 0; i < isAssertedTimeoutSeconds; i++ {
+ time.Sleep(time.Second)
+ if w.IsAsserted() {
+ t.Fatalf("waker asserted unexpectedly")
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 12bc1b5b5..558b06df0 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
return st
}
+// State returns the current state of the TCB.
+func (t *TCB) State() Result {
+ return t.state
+}
+
// IsAlive returns true as long as the connection is established(Alive)
// or connecting state.
func (t *TCB) IsAlive() bool {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 647b2067a..73608783c 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,6 +15,9 @@
package udp
import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -93,6 +96,7 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
+ sndBufSizeMax int
state EndpointState
route stack.Route `state:"manual"`
dstPort uint16
@@ -102,9 +106,10 @@ type endpoint struct {
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
multicastLoop bool
- reusePort bool
+ portFlags ports.Flags
bindToDevice tcpip.NICID
broadcast bool
+ noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -158,7 +163,7 @@ type multicastMembership struct {
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
- return &endpoint{
+ e := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
@@ -180,10 +185,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
multicastTTL: 1,
multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
state: StateInitial,
uniqueID: s.UniqueID(),
}
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
+ }
+
+ return e
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -213,8 +231,8 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
}
@@ -247,11 +265,6 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -430,24 +443,33 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
var route *stack.Route
+ var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
var dstPort uint16
if to == nil {
route = &e.route
dstPort = e.dstPort
-
- if route.IsResolutionRequired() {
- // Promote lock to exclusive if using a shared route, given that it may need to
- // change in Route.Resolve() call below.
+ resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
+ // Promote lock to exclusive if using a shared route, given that it may
+ // need to change in Route.Resolve() call below.
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
if e.state != StateConnected {
- return 0, nil, tcpip.ErrInvalidEndpointState
+ err = tcpip.ErrInvalidEndpointState
+ }
+ if err == nil && route.IsResolutionRequired() {
+ ch, err = route.Resolve(waker)
}
+
+ e.mu.Unlock()
+ e.mu.RLock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != StateConnected {
+ err = tcpip.ErrInvalidEndpointState
+ }
+ return
}
} else {
// Reject destination address if it goes through a different
@@ -461,10 +483,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
- if to.Addr == header.IPv4Broadcast && !e.broadcast {
- return 0, nil, tcpip.ErrBroadcastDisabled
- }
-
dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
@@ -478,10 +496,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
route = &r
dstPort = dst.Port
+ resolve = route.Resolve
+ }
+
+ if !e.broadcast && route.IsOutboundBroadcast() {
+ return 0, nil, tcpip.ErrBroadcastDisabled
}
if route.IsResolutionRequired() {
- if ch, err := route.Resolve(nil); err != nil {
+ if ch, err := resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
return 0, ch, tcpip.ErrNoLinkAddress
}
@@ -507,7 +530,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -531,6 +554,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.multicastLoop = v
e.mu.Unlock()
+ case tcpip.NoChecksumOption:
+ e.mu.Lock()
+ e.noChecksum = v
+ e.mu.Unlock()
+
case tcpip.ReceiveTOSOption:
e.mu.Lock()
e.receiveTOS = v
@@ -552,10 +580,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.mu.Unlock()
case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.portFlags.MostRecent = v
+ e.mu.Unlock()
case tcpip.ReusePortOption:
e.mu.Lock()
- e.reusePort = v
+ e.portFlags.LoadBalanced = v
e.mu.Unlock()
case tcpip.V6OnlyOption:
@@ -581,6 +612,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
e.multicastTTL = uint8(v)
@@ -602,8 +640,43 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.mu.Unlock()
case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err))
+ }
+
+ if v < rs.Min {
+ v = rs.Min
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ e.mu.Lock()
+ e.rcvBufSizeMax = v
+ e.mu.Unlock()
+ return nil
case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err))
+ }
+
+ if v < ss.Min {
+ v = ss.Min
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
}
return nil
@@ -743,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Lock()
e.bindToDevice = id
e.mu.Unlock()
+
+ case tcpip.SocketDetachFilterOption:
+ return nil
}
return nil
}
@@ -765,6 +841,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.NoChecksumOption:
+ e.mu.RLock()
+ v := e.noChecksum
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.ReceiveTOSOption:
e.mu.RLock()
v := e.receiveTOS
@@ -789,11 +871,15 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
case tcpip.ReuseAddressOption:
- return false, nil
+ e.mu.RLock()
+ v := e.portFlags.MostRecent
+ e.mu.RUnlock()
+
+ return v, nil
case tcpip.ReusePortOption:
e.mu.RLock()
- v := e.reusePort
+ v := e.portFlags.LoadBalanced
e.mu.RUnlock()
return v, nil
@@ -830,6 +916,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
v := int(e.multicastTTL)
@@ -848,7 +938,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -895,22 +985,29 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error {
- // Allocate a buffer for the UDP header.
- hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
+ Data: data,
+ })
+ pkt.Owner = owner
- // Initialize the header.
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ // Initialize the UDP header.
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
- length := uint16(hdr.UsedLength() + data.Size())
+ length := uint16(pkt.Size())
udp.Encode(&header.UDPFields{
SrcPort: localPort,
DstPort: remotePort,
Length: length,
})
- // Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
+ (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
@@ -921,12 +1018,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
if useDefaultTTL {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{
- Header: hdr,
- Data: data,
- TransportHeader: buffer.View(udp),
- Owner: owner,
- }); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: ProtocolNumber,
+ TTL: ttl,
+ TOS: tos,
+ }, pkt); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return err
}
@@ -958,6 +1054,11 @@ func (e *endpoint) Disconnect() *tcpip.Error {
id stack.TransportEndpointID
btd tcpip.NICID
)
+
+ // We change this value below and we need the old value to unregister
+ // the endpoint.
+ boundPortFlags := e.boundPortFlags
+
// Exclude ephemerally bound endpoints.
if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
@@ -970,16 +1071,17 @@ func (e *endpoint) Disconnect() *tcpip.Error {
return err
}
e.state = StateBound
+ boundPortFlags = e.boundPortFlags
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
@@ -1051,6 +1153,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
}
+ oldPortFlags := e.boundPortFlags
+
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
@@ -1058,7 +1162,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
}
e.ID = id
@@ -1122,22 +1226,17 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- flags := ports.Flags{
- LoadBalanced: e.reusePort,
- // FIXME(b/129164367): Support SO_REUSEADDR.
- MostRecent: false,
- }
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
if err != nil {
return id, e.bindToDevice, err
}
- e.boundPortFlags = flags
id.LocalPort = port
}
+ e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
return id, e.bindToDevice, err
@@ -1269,22 +1368,47 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
- hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() {
+ hdr := header.UDP(pkt.TransportHeader().View())
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
- pkt.Data.TrimFront(header.UDPMinimumSize)
+ // Never receive from a multicast address.
+ if header.IsV4MulticastAddress(id.RemoteAddress) ||
+ header.IsV6MulticastAddress(id.RemoteAddress) {
+ e.stack.Stats().UDP.InvalidSourceAddress.Increment()
+ e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ // Verify checksum unless RX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value means
+ // the transmitter omitted the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
+ (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ if hdr.CalculateChecksum(xsum) != 0xffff {
+ // Checksum Error.
+ e.stack.Stats().UDP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
+ return
+ }
+ }
- e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
+ e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed {
e.rcvMu.Unlock()
@@ -1317,15 +1441,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Save any useful information from the network header to the packet.
switch r.NetProto {
case header.IPv4ProtocolNumber:
- packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
- packet.packetInfo.LocalAddr = r.LocalAddress
- packet.packetInfo.DestinationAddr = r.RemoteAddress
- packet.packetInfo.NIC = r.NICID()
+ packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS()
case header.IPv6ProtocolNumber:
- packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
+ packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS()
}
- packet.timestamp = e.stack.NowNanoseconds()
+ // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
+ // address. packetInfo.LocalAddr should hold a unicast address that can be
+ // used to respond to the incoming packet.
+ packet.packetInfo.LocalAddr = r.LocalAddress
+ packet.packetInfo.DestinationAddr = r.LocalAddress
+ packet.packetInfo.NIC = r.NICID()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
@@ -1336,7 +1463,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
e.mu.RLock()
defer e.mu.RUnlock()
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a674ceb68..c67e0ba95 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
route: r,
@@ -61,7 +61,7 @@ type ForwarderRequest struct {
stack *stack.Stack
route *stack.Route
id stack.TransportEndpointID
- pkt stack.PacketBuffer
+ pkt *stack.PacketBuffer
}
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
@@ -73,7 +73,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
@@ -82,6 +82,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.route = r.route.Clone()
ep.dstPort = r.id.RemotePort
ep.RegisterNICID = r.route.NICID()
+ ep.boundPortFlags = ep.portFlags
ep.state = StateConnected
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 52af6de22..63d4bed7c 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -32,9 +32,24 @@ import (
const (
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
+
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4KiB bytes.
+
+ // DefaultSendBufferSize is the default size of the send buffer for
+ // an endpoint.
+ DefaultSendBufferSize = 32 << 10 // 32KiB
+
+ // DefaultReceiveBufferSize is the default size of the receive buffer
+ // for an endpoint.
+ DefaultReceiveBufferSize = 32 << 10 // 32KiB
+
+ // MaxBufferSize is the largest size a receive/send buffer can grow to.
+ MaxBufferSize = 4 << 20 // 4MiB
)
-type protocol struct{}
+type protocol struct {
+}
// Number returns the udp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
@@ -66,15 +81,9 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
- // Get the header then trim it from the view.
- h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok {
- // Malformed packet.
- r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
- return true
- }
- if int(header.UDP(h).Length()) > pkt.Data.Size() {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ hdr := header.UDP(pkt.TransportHeader().View())
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true
@@ -121,7 +130,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
- payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
+ payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
@@ -130,20 +139,21 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
// For example, a raw or packet socket may use what UDP
// considers an unreachable destination. Thus we deep copy pkt
// to prevent multiple ownership and SR errors.
- newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...)
- payload := newNetHeader.ToVectorisedView()
- payload.Append(pkt.Data.ToView().ToVectorisedView())
+ newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...)
+ newHeader = append(newHeader, pkt.TransportHeader().View()...)
+ payload := newHeader.ToVectorisedView()
+ payload.AppendView(pkt.Data.ToView())
payload.CapLength(payloadLen)
- hdr := buffer.NewPrependable(headerLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4DstUnreachable)
- pkt.SetCode(header.ICMPv4PortUnreachable)
- pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload,
+ icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
})
+ icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt)
case header.IPv6AddressSize:
if !r.Stack().AllowICMPMessage() {
@@ -164,34 +174,35 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
available := int(mtu) - headerLen
- payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ payloadLen := len(network) + len(transport) + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader})
+ payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport})
payload.Append(pkt.Data)
payload.CapLength(payloadLen)
- hdr := buffer.NewPrependable(headerLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize))
- pkt.SetType(header.ICMPv6DstUnreachable)
- pkt.SetCode(header.ICMPv6PortUnreachable)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload,
+ icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
})
+ icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6PortUnreachable)
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data))
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt)
}
return true
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -201,6 +212,12 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize)
+ return ok
+}
+
// NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 8acaa607a..71776d6db 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -83,16 +83,18 @@ type header4Tuple struct {
type testFlow int
const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ reverseMulticast4 // V4 multicast src. Must fail.
+ reverseMulticast6 // V6 multicast src. Must fail.
)
func (flow testFlow) String() string {
@@ -117,6 +119,10 @@ func (flow testFlow) String() string {
return "broadcast"
case broadcastIn6:
return "broadcastIn6"
+ case reverseMulticast4:
+ return "reverseMulticast4"
+ case reverseMulticast6:
+ return "reverseMulticast6"
default:
return "unknown"
}
@@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
h.dstAddr.Addr = multicastV6Addr
}
}
+ if flow.isReverseMulticast() {
+ h.srcAddr.Addr = flow.getMcastAddr()
+ }
return h
}
@@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
// endpoint for this flow.
func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast:
+ case unicastV4, multicastV4, broadcast, reverseMulticast4:
return ipv4.ProtocolNumber
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool {
switch flow {
case unicastV6Only, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool {
switch flow {
case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool {
switch flow {
case broadcast, broadcastIn6:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool {
switch flow {
case unicastV4in6, multicastV4in6, broadcastIn6:
return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
}
}
+func (flow testFlow) isReverseMulticast() bool {
+ switch flow {
+ case reverseMulticast4, reverseMulticast6:
+ return true
+ default:
+ return false
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -292,15 +310,15 @@ func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Optio
wep = sniffer.New(ep)
}
if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatalf("CreateNIC failed: %s", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -370,8 +388,8 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
h := flow.header4Tuple(outgoing)
checkers = append(
@@ -391,17 +409,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
} else {
- c.injectV6Packet(payload, &h, true /* valid */)
+ buf := c.buildV6Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
}
}
-// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV6Packet creates a V6 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -420,16 +442,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
- l := uint16(header.UDPMinimumSize + len(payload))
- if !valid {
- // Change the UDP payload length to corrupt the header
- // as requested by the caller.
- l++
- }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: l,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
@@ -439,19 +455,12 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
- })
+ return buf
}
-// injectV4Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV4Packet creates a V4 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -485,13 +494,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
-
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
- })
+ return buf
}
func newPayload() []byte {
@@ -513,7 +516,7 @@ func TestBindToDeviceOption(t *testing.T) {
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
@@ -647,7 +650,7 @@ func TestBindEphemeralPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}
@@ -658,19 +661,19 @@ func TestBindReservedPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
addr, err := c.ep.GetLocalAddress()
if err != nil {
- t.Fatalf("GetLocalAddress failed: %v", err)
+ t.Fatalf("GetLocalAddress failed: %s", err)
}
// We can't bind the address reserved by the connected endpoint above.
{
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
@@ -681,7 +684,7 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// We can't bind ipv4-any on the port reserved by the connected endpoint
@@ -691,7 +694,7 @@ func TestBindReservedPort(t *testing.T) {
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
@@ -701,11 +704,11 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
}
@@ -718,7 +721,7 @@ func TestV4ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -733,7 +736,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
// Bind to v4 mapped wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -748,7 +751,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -763,7 +766,7 @@ func TestV6ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -800,7 +803,10 @@ func TestV4ReadSelfSource(t *testing.T) {
h := unicastV4.header4Tuple(incoming)
h.srcAddr = h.dstAddr
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
@@ -821,7 +827,7 @@ func TestV4ReadOnV4(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -884,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
}
}
+// TestReadFromMulticast checks that an endpoint will NOT receive a packet
+// that was sent with multicast SOURCE address.
+func TestReadFromMulticast(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ testFailingRead(c, flow, false /* expectReadError */)
+ })
+ }
+}
+
+// TestReadFromMulticaststats checks that a discarded packet
+// that that was sent with multicast SOURCE address increments
+// the correct counters and that a regular packet does not.
+func TestReadFromMulticastStats(t *testing.T) {
+ t.Helper()
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ c.injectPacket(flow, payload)
+
+ var want uint64 = 0
+ if flow.isReverseMulticast() {
+ want = 1
+ }
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want {
+ t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want {
+ t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
// and receive broadcast and unicast data.
func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
@@ -959,7 +1019,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ c.t.Fatalf("Write failed: %s", err)
}
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
@@ -1009,7 +1069,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
p := testDualWrite(c)
@@ -1026,7 +1086,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV6)
@@ -1047,7 +1107,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV4in6)
@@ -1074,7 +1134,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Write to v6 address.
@@ -1089,7 +1149,7 @@ func TestV6WriteOnConnected(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV6)
@@ -1103,7 +1163,7 @@ func TestV4WriteOnConnected(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV4)
@@ -1238,7 +1298,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testRead(c, unicastV4)
@@ -1249,6 +1309,105 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
}
}
+func TestReadIPPacketInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ expectedLocalAddr tcpip.Address
+ expectedDestAddr tcpip.Address
+ }{
+ {
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ expectedLocalAddr: stackAddr,
+ expectedDestAddr: stackAddr,
+ },
+ {
+ name: "IPv4 multicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: multicastV4,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: multicastAddr,
+ expectedDestAddr: multicastAddr,
+ },
+ {
+ name: "IPv4 broadcast",
+ proto: header.IPv4ProtocolNumber,
+ flow: broadcast,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: broadcastAddr,
+ expectedDestAddr: broadcastAddr,
+ },
+ {
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ expectedLocalAddr: stackV6Addr,
+ expectedDestAddr: stackV6Addr,
+ },
+ {
+ name: "IPv6 multicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: multicastV6,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: multicastV6Addr,
+ expectedDestAddr: multicastV6Addr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(test.proto)
+
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ if test.flow.isMulticast() {
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(%+v): %s:", ifoptSet, err)
+ }
+ }
+
+ if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil {
+ t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err)
+ }
+
+ testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: 1,
+ LocalAddr: test.expectedLocalAddr,
+ DestinationAddr: test.expectedDestAddr,
+ }))
+
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
+ t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
+ }
+ })
+ }
+}
+
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1263,6 +1422,30 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
+func TestNoChecksum(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Disable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ // This option is effective on IPv4 only.
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
+
+ // Enable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
+ })
+ }
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1510,12 +1693,12 @@ func TestMulticastInterfaceOption(t *testing.T) {
Port: stackPort,
}
if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
}
if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ c.t.Fatalf("SetSockOpt failed: %s", err)
}
// Verify multicast interface addr and NIC were set correctly.
@@ -1523,7 +1706,7 @@ func TestMulticastInterfaceOption(t *testing.T) {
ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
var ifoptGot tcpip.MulticastInterfaceOption
if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
+ c.t.Fatalf("GetSockOpt failed: %s", err)
}
if ifoptGot != ifoptWant {
c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
@@ -1583,9 +1766,8 @@ func TestV4UnknownDestination(t *testing.T) {
return
}
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
}
@@ -1661,9 +1843,8 @@ func TestV6UnknownDestination(t *testing.T) {
return
}
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
}
@@ -1695,7 +1876,7 @@ func TestV6UnknownDestination(t *testing.T) {
}
// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats get incremented.
+// global and endpoint stats are incremented.
func TestIncrementMalformedPacketsReceived(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1703,20 +1884,271 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
payload := newPayload()
- c.t.Helper()
h := unicastV6.header4Tuple(incoming)
- c.injectV6Packet(payload, &h, false /* !valid */)
+ buf := c.buildV6Packet(payload, &h)
- var want uint64 = 1
+ // Invalidate the UDP header length field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetLength(u.Length() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
}
if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+}
+
+// TestShortHeader verifies that when a packet with a too-short UDP header is
+// received, the malformed received global stat gets incremented.
+func TestShortHeader(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ h := unicastV6.header4Tuple(incoming)
+
+ // Allocate a buffer for an IPv6 and too-short UDP header.
+ const udpSize = header.UDPMinimumSize - 1
+ buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
+ })
+
+ // Initialize the UDP header.
+ udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
+ udpHdr.Encode(&header.UDPFields{
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
+ Length: header.UDPMinimumSize,
+ })
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+ // Copy all but the last byte of the UDP header into the packet.
+ copy(buf[header.IPv6MinimumSize:], udpHdr)
+
+ // Inject packet.
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
+ t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+
+ // Invalidate the UDP header checksum field, taking care to avoid
+ // overflow to zero, which would disable checksum validation.
+ for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
+ u.SetChecksum(u.Checksum() + 1)
+ if u.Checksum() != 0 {
+ break
+ }
+ }
+
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+
+ // Invalidate the UDP header checksum field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV4 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV6 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV4 verifies if the checksum value is zero, global and
+// endpoint states are *not* incremented (UDP checksum is optional on IPv4).
+func TestChecksumZeroV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 0
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV6 verifies if the checksum value is zero, global and
+// endpoint states are incremented (UDP checksum is *not* optional on IPv6).
+func TestChecksumZeroV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
}
}
@@ -1730,15 +2162,15 @@ func TestShutdownRead(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingRead(c, unicastV6, true /* expectReadError */)
@@ -1761,11 +2193,11 @@ func TestShutdownWrite(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
@@ -1807,3 +2239,192 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn
c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
}
}
+
+func TestOutgoingSubnetBroadcast(t *testing.T) {
+ const nicID1 = 1
+
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+ ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 31,
+ }
+ ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
+ ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
+ ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 32,
+ }
+ ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
+ ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+ remNetAddr := tcpip.AddressWithPrefix{
+ Address: "\x64\x0a\x7b\x18",
+ PrefixLen: 24,
+ }
+ remNetSubnet := remNetAddr.Subnet()
+ remNetSubnetBcast := remNetSubnet.Broadcast()
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ requiresBroadcastOpt bool
+ }{
+ {
+ name: "IPv4 Broadcast to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4SubnetBcast,
+ requiresBroadcastOpt: true,
+ },
+ {
+ name: "IPv4 Broadcast to local /31 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix31,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet31,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet31Bcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to local /32 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix32,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet32,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet32Bcast,
+ requiresBroadcastOpt: false,
+ },
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 'Broadcast' to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv6Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv6SubnetBcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to remote subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: remNetSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ requiresBroadcastOpt: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ }
+
+ s.SetRouteTable(test.routes)
+
+ var netProto tcpip.NetworkProtocolNumber
+ switch l := len(test.remoteAddr); l {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ netProto = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
+ }
+ defer ep.Close()
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ to := tcpip.FullAddress{
+ Addr: test.remoteAddr,
+ Port: 80,
+ }
+ opts := tcpip.WriteOptions{To: &to}
+ expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled
+ if !test.requiresBroadcastOpt {
+ expectedErrWithoutBcastOpt = nil
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != nil {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+ })
+ }
+}