path: root/pkg/tcpip/transport
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint.go                  18
-rw-r--r--  pkg/tcpip/transport/packet/endpoint.go               197
-rw-r--r--  pkg/tcpip/transport/packet/endpoint_state.go          19
-rw-r--r--  pkg/tcpip/transport/raw/endpoint.go                  175
-rw-r--r--  pkg/tcpip/transport/tcp/BUILD                         11
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go                    166
-rw-r--r--  pkg/tcpip/transport/tcp/connect.go                    71
-rw-r--r--  pkg/tcpip/transport/tcp/dispatcher.go                148
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go                  159
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint_state.go             39
-rw-r--r--  pkg/tcpip/transport/tcp/protocol.go                   45
-rw-r--r--  pkg/tcpip/transport/tcp/snd.go                        63
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_noracedetector_test.go    32
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_sack_test.go              14
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_test.go                  608
-rw-r--r--  pkg/tcpip/transport/tcp/testing/context/context.go     8
-rw-r--r--  pkg/tcpip/transport/tcp/timer.go                       1
-rw-r--r--  pkg/tcpip/transport/tcp/timer_test.go                 47
-rw-r--r--  pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go      5
-rw-r--r--  pkg/tcpip/transport/udp/endpoint.go                  134
-rw-r--r--  pkg/tcpip/transport/udp/protocol.go                    21
-rw-r--r--  pkg/tcpip/transport/udp/udp_test.go                  459
22 files changed, 1674 insertions, 766 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 8ce294002..4612be4e7 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -344,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+ }
return nil
}
@@ -744,15 +748,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
- if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply {
+ h := header.ICMPv4(pkt.TransportHeader)
+ if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
- if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply {
+ h := header.ICMPv6(pkt.TransportHeader)
+ if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
@@ -786,12 +790,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
},
}
- packet.data = pkt.Data
+ // An ICMP socket's data includes the ICMP header.
+ packet.data = pkt.TransportHeader.ToVectorisedView()
+ packet.data.Append(pkt.Data)
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
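
Note: the hunk above makes the ICMP endpoint deliver pkt.TransportHeader (the ICMP header itself) followed by pkt.Data, instead of only the payload. A minimal, self-contained Go sketch of that header-then-payload assembly; the view type below is purely illustrative and is not gvisor's buffer package:

package main

import "fmt"

// view mimics the idea of a vectorised view: an ordered list of byte
// slices treated as one logical packet, appended without copying.
type view struct{ chunks [][]byte }

func (v *view) appendChunk(b []byte) { v.chunks = append(v.chunks, b) }

func (v *view) size() int {
    n := 0
    for _, c := range v.chunks {
        n += len(c)
    }
    return n
}

func main() {
    icmpHeader := []byte{8, 0, 0xf7, 0xff, 0, 1, 0, 1} // fake echo header bytes
    payload := []byte("ping payload")

    // As in the diff: start the delivered data with the transport (ICMP)
    // header, then append the payload, so userspace sees the header it
    // expects from an ICMP socket.
    var data view
    data.appendChunk(icmpHeader)
    data.appendChunk(payload)
    fmt.Println("delivered bytes:", data.size())
}
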
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index baf08eda6..0e46e6355 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,6 +25,8 @@
package packet
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -43,6 +45,9 @@ type packet struct {
timestampNS int64
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ // packetInfo holds additional information about the packet, such as
+ // its protocol and packet type.
+ packetInfo tcpip.LinkPacketInfo
}
// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
@@ -71,11 +76,17 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
+ boundNIC tcpip.NICID
+
+ // lastErrorMu protects lastError.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
}
// NewEndpoint returns a new packet endpoint.
@@ -92,6 +103,17 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
sndBufSize: 32 * 1024,
}
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ ep.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ ep.rcvBufSizeMax = rs.Default
+ }
+
if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
return nil, err
}
@@ -132,8 +154,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.PacketEndpoint.ReadPacket.
+func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -158,9 +180,18 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
*addr = packet.senderAddr
}
+ if info != nil {
+ *info = packet.packetInfo
+ }
+
return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
}
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return ep.ReadPacket(addr, nil)
+}
+
func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// TODO(b/129292371): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -215,12 +246,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound {
- return tcpip.ErrAlreadyBound
+ if ep.bound && ep.boundNIC == addr.NIC {
+ // If the NIC being bound is the same then just return success.
+ return nil
}
// Unregister endpoint with all the nics.
ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.bound = false
// Bind endpoint to receive packets from specific interface.
if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
@@ -228,6 +261,7 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
ep.bound = true
+ ep.boundNIC = addr.NIC
return nil
}
@@ -264,7 +298,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
@@ -274,11 +314,63 @@ func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := ep.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ ep.mu.Lock()
+ ep.sndBufSizeMax = v
+ ep.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := ep.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ ep.rcvMu.Lock()
+ ep.rcvBufSizeMax = v
+ ep.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (ep *endpoint) takeLastError() *tcpip.Error {
+ ep.lastErrorMu.Lock()
+ defer ep.lastErrorMu.Unlock()
+
+ err := ep.lastError
+ ep.lastError = nil
+ return err
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return ep.takeLastError()
+ }
return tcpip.ErrNotSupported
}
@@ -289,7 +381,32 @@ func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() {
+ p := ep.rcvList.Front()
+ v = p.data.Size()
+ }
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSizeMax
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
@@ -323,40 +440,66 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
NIC: nicID,
Addr: tcpip.Address(hdr.SourceAddress()),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(localAddr),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = pkt.Data
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip the link header from pkt.Header.
+ pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):])
+ combinedVV := pkt.Header.View().ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
var linkHeader buffer.View
- if len(pkt.LinkHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ var combinedVV buffer.VectorisedView
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if len(pkt.LinkHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
- } else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ combinedVV = linkHeader.ToVectorisedView()
+ }
+ if pkt.PktType == tcpip.PacketOutgoing {
+ // For outgoing packets the link, network, and transport
+ // headers are normally carried in pkt.Header, unless a
+ // raw socket is in use, in which case pkt.Header may be
+ // nil.
+ combinedVV.AppendView(pkt.Header.View())
}
- combinedVV := linkHeader.ToVectorisedView()
combinedVV.Append(pkt.Data)
packet.data = combinedVV
}
- packet.timestampNS = ep.stack.NowNanoseconds()
+ packet.timestampNS = ep.stack.Clock().NowNanoseconds()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
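
Note: the new SetSockOptInt handling above clamps requested send/receive buffer sizes to the stack-wide min/max before storing them, rather than rejecting out-of-range values. A small, self-contained Go sketch of that clamping policy (bufferLimits is a stand-in for the stack's SendBufferSizeOption/ReceiveBufferSizeOption, not gvisor API):

package main

import "fmt"

// bufferLimits mirrors the shape of the stack-level buffer options used in
// the diff: a default plus hard min/max bounds.
type bufferLimits struct{ Min, Default, Max int }

// clamp applies the same policy as SetSockOptInt above: requested sizes
// outside [Min, Max] are silently pulled back into range.
func clamp(requested int, l bufferLimits) int {
    if requested > l.Max {
        return l.Max
    }
    if requested < l.Min {
        return l.Min
    }
    return requested
}

func main() {
    limits := bufferLimits{Min: 4 << 10, Default: 32 << 10, Max: 4 << 20}
    fmt.Println(clamp(1, limits))      // 4096
    fmt.Println(clamp(64<<10, limits)) // 65536
    fmt.Println(clamp(1<<30, limits))  // 4194304
}
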
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 9b88f17e4..e2fa96d17 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() {
panic(*err)
}
}
+
+// saveLastError is invoked by stateify.
+func (ep *endpoint) saveLastError() string {
+ if ep.lastError == nil {
+ return ""
+ }
+
+ return ep.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (ep *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ ep.lastError = tcpip.StringToError(s)
+}
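
Note: saveLastError/loadLastError above serialize the endpoint's last error as a plain string ("" meaning nil) so the save/restore machinery never has to persist an error value directly. A minimal sketch of that symmetric pair, using the standard errors package in place of tcpip.StringToError:

package main

import (
    "errors"
    "fmt"
)

// saveErr stores an error as a string; "" means no error was pending.
func saveErr(err error) string {
    if err == nil {
        return ""
    }
    return err.Error()
}

// loadErr restores the error; the real code maps the string back through
// tcpip.StringToError instead of creating a new error value.
func loadErr(s string) error {
    if s == "" {
        return nil
    }
    return errors.New(s)
}

func main() {
    saved := saveErr(errors.New("connection refused"))
    fmt.Printf("restored: %v\n", loadErr(saved))
}
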
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index a406d815e..f85a68554 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,8 @@
package raw
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -61,21 +63,23 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
+ hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
rcvList rawPacketList
- rcvBufSizeMax int `state:".(int)"`
rcvBufSize int
+ rcvBufSizeMax int `state:".(int)"`
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- connected bool
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ connected bool
+ bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
@@ -91,7 +95,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber {
+ if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
@@ -103,8 +107,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
},
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
associated: associated,
+ hdrIncluded: !associated,
+ }
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
}
// Unassociated endpoints are write-only and users call Write() with IP
@@ -168,10 +184,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if !e.associated {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
- }
-
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -201,6 +213,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // We can create, but not write to, unassociated IPv6 endpoints.
+ if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -244,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if !e.associated {
+ if e.hdrIncluded {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -305,12 +322,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrNoRoute
}
- // We don't support IPv6 yet, so this has to be an IPv4 address.
- if len(opts.To.Addr) != header.IPv4AddressSize {
- e.mu.RUnlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
// Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
@@ -340,17 +351,13 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- switch e.NetProto {
- case header.IPv4ProtocolNumber:
- if !e.associated {
- if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- }); err != nil {
- return 0, nil, err
- }
- break
+ if e.hdrIncluded {
+ if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ }); err != nil {
+ return 0, nil, err
}
-
+ } else {
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
@@ -359,9 +366,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}); err != nil {
return 0, nil, err
}
-
- default:
- return 0, nil, tcpip.ErrUnknownProtocol
}
return int64(len(payloadBytes)), nil, nil
@@ -386,11 +390,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // We don't support IPv6 yet.
- if len(addr.Addr) != header.IPv4AddressSize {
- return tcpip.ErrInvalidEndpointState
- }
-
nic := addr.NIC
if e.bound {
if e.BindNICID == 0 {
@@ -456,14 +455,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // Callers must provide an IPv4 address or no network address (for
- // binding to a NIC, but not an address).
- if len(addr.Addr) != 0 && len(addr.Addr) != 4 {
- return tcpip.ErrInvalidEndpointState
- }
-
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -513,17 +506,69 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ e.hdrIncluded = v
+ e.mu.Unlock()
+ return nil
+ }
return tcpip.ErrUnknownProtocolOption
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ e.rcvMu.Lock()
+ e.rcvBufSizeMax = v
+ e.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -543,6 +588,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.KeepaliveEnabledOption:
return false, nil
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ v := e.hdrIncluded
+ e.mu.Unlock()
+ return v, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -563,7 +614,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -582,8 +633,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
- // Drop the packet if our buffer is currently full.
- if e.rcvClosed {
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e. an endpoint created with IPPROTO_RAW). Such endpoints are send-only.
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
@@ -627,16 +685,25 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
},
}
- headers := append(buffer.View(nil), pkt.NetworkHeader...)
- headers = append(headers, pkt.TransportHeader...)
- combinedVV := headers.ToVectorisedView()
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ var combinedVV buffer.VectorisedView
+ if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+ headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader))
+ headers = append(headers, pkt.NetworkHeader...)
+ headers = append(headers, pkt.TransportHeader...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView()
+ }
combinedVV.Append(pkt.Data)
packet.data = combinedVV
- packet.timestampNS = e.stack.NowNanoseconds()
+ packet.timestampNS = e.stack.Clock().NowNanoseconds()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
-
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
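
Note: the raw endpoint changes above replace the "unassociated means header-included" assumption with an explicit hdrIncluded flag (toggled via tcpip.IPHdrIncludedOption) that selects between WriteHeaderIncludedPacket and the normal WritePacket path. A self-contained Go sketch of that write-path split; rawWriter and its method are illustrative stand-ins, not gvisor code:

package main

import "fmt"

// rawWriter sketches the split: when hdrIncluded is set (IP_HDRINCL-style
// behaviour) the caller's bytes are sent as a complete IP packet;
// otherwise the stack prepends its own network header.
type rawWriter struct{ hdrIncluded bool }

func (w *rawWriter) write(payload []byte) string {
    if w.hdrIncluded {
        return fmt.Sprintf("send %d bytes as-is (caller built the IP header)", len(payload))
    }
    return fmt.Sprintf("prepend IP header, then send %d payload bytes", len(payload))
}

func main() {
    w := rawWriter{}
    fmt.Println(w.write([]byte("data")))
    w.hdrIncluded = true // analogous to SetSockOptBool(tcpip.IPHdrIncludedOption, true)
    fmt.Println(w.write([]byte{0x45, 0x00}))
}
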
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index e26f01fae..18ff89ffc 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -76,7 +76,7 @@ go_library(
)
go_test(
- name = "tcp_test",
+ name = "tcp_x_test",
size = "medium",
srcs = [
"dual_stack_test.go",
@@ -86,6 +86,7 @@ go_test(
"tcp_test.go",
"tcp_timestamp_test.go",
],
+ shard_count = 10,
deps = [
":tcp",
"//pkg/sync",
@@ -115,3 +116,11 @@ go_test(
"//pkg/tcpip/seqnum",
],
)
+
+go_test(
+ name = "tcp_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ library = ":tcp",
+ deps = ["//pkg/sleep"],
+)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index ad197e8db..6e00e5526 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
@@ -199,9 +198,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
}
// createConnectingEndpoint creates a new endpoint in a connecting state, with
-// the connection parameters given by the arguments. The endpoint is returned
-// with n.mu held.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
+// the connection parameters given by the arguments.
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -222,33 +220,12 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
- // Create sender and receiver.
- //
- // The receiver at least temporarily has a zero receive window scale,
- // but the caller may change it (before starting the protocol loop).
- n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
- n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize()))
// Bootstrap the auto tuning algorithm. Starting at zero will result in
// a large step function on the first window adjustment causing the
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
- // Lock the endpoint before registering to ensure that no out of
- // band changes are possible due to incoming packets etc till
- // the endpoint is done initializing.
- n.mu.Lock()
-
- // Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, ports.Flags{LoadBalanced: n.reusePort}, n.boundBindToDevice); err != nil {
- n.mu.Unlock()
- n.Close()
- return nil, err
- }
-
- n.isRegistered = true
- n.registeredReusePort = n.reusePort
-
- return n, nil
+ return n
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
@@ -259,10 +236,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
- if err != nil {
- return nil, err
- }
+ ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+
+ // Lock the endpoint before registering to ensure that no out of
+ // band changes are possible due to incoming packets etc till
+ // the endpoint is done initializing.
+ ep.mu.Lock()
ep.owner = owner
// listenEP is nil when listenContext is used by tcp.Forwarder.
@@ -270,18 +249,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if l.listenEP != nil {
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
+
l.listenEP.mu.Unlock()
// Ensure we release any registrations done by the newly
// created endpoint.
ep.mu.Unlock()
ep.Close()
- // Wake up any waiters. This is strictly not required normally
- // as a socket that was never accepted can't really have any
- // registered waiters except when stack.Wait() is called which
- // waits for all registered endpoints to stop and expects an
- // EventHUp.
- ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
@@ -290,21 +264,44 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// to the newly created endpoint.
l.listenEP.propagateInheritableOptionsLocked(ep)
+ if !ep.reserveTupleLocked() {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ l.listenEP.mu.Unlock()
+ }
+
+ return nil, tcpip.ErrConnectionAborted
+ }
+
deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
+ // Register new endpoint so that packets are routed to it.
+ if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
+
+ ep.drainClosingSegmentQueue()
+
+ return nil, err
+ }
+
+ ep.isRegistered = true
+
// Perform the 3-way handshake.
- h := newPassiveHandshake(ep, ep.rcv.rcvWnd, isn, irs, opts, deferAccept)
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.mu.Unlock()
ep.Close()
- // Wake up any waiters. This is strictly not required normally
- // as a socket that was never accepted can't really have any
- // registered waiters except when stack.Wait() is called which
- // waits for all registered endpoints to stop and expects an
- // EventHUp.
- ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ ep.notifyAborted()
if l.listenEP != nil {
l.removePendingEndpoint(ep)
@@ -380,6 +377,43 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
// Precondition: e.mu and n.mu must be held.
func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
+ n.portFlags = e.portFlags
+ n.boundBindToDevice = e.boundBindToDevice
+ n.boundPortFlags = e.boundPortFlags
+}
+
+// reserveTupleLocked reserves an accepted endpoint's tuple.
+//
+// Preconditions:
+// * propagateInheritableOptionsLocked has been called.
+// * e.mu is held.
+func (e *endpoint) reserveTupleLocked() bool {
+ dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort}
+ if !e.stack.ReserveTuple(
+ e.effectiveNetProtos,
+ ProtocolNumber,
+ e.ID.LocalAddress,
+ e.ID.LocalPort,
+ e.boundPortFlags,
+ e.boundBindToDevice,
+ dest,
+ ) {
+ return false
+ }
+
+ e.isPortReserved = true
+ e.boundDest = dest
+ return true
+}
+
+// notifyAborted wakes up any waiters on registered, but not accepted
+// endpoints.
+//
+// This is strictly not required normally as a socket that was never accepted
+// can't really have any registered waiters except when stack.Wait() is called
+// which waits for all registered endpoints to stop and expects an EventHUp.
+func (e *endpoint) notifyAborted() {
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
// handleSynSegment is called in its own goroutine once the listening endpoint
@@ -536,6 +570,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
+ iss := s.ackNumber - 1
+ irs := s.sequenceNumber - 1
+
// Since SYN cookies are in use this is potentially an ACK to a
// SYN-ACK we sent but don't have a half open connection state
// as cookies are being used to protect against a potential SYN
@@ -546,7 +583,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// when under a potential syn flood attack.
//
// Validate the cookie.
- data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
+ data, ok := ctx.isCookieValid(s.id, iss, irs)
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -571,16 +608,34 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{})
- if err != nil {
+ n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+
+ n.mu.Lock()
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ e.propagateInheritableOptionsLocked(n)
+
+ if !n.reserveTupleLocked() {
+ n.mu.Unlock()
+ n.Close()
+
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
- // Propagate any inheritable options from the listening endpoint
- // to the newly created endpoint.
- e.propagateInheritableOptionsLocked(n)
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+ n.mu.Unlock()
+ n.Close()
+
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ n.isRegistered = true
// clear the tsOffset for the newly created
// endpoint as the Timestamp was already
@@ -589,10 +644,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.tsOffset = 0
// Switch state to connected.
- // We do not use transitionToStateEstablishedLocked here as there is
- // no handshake state available when doing a SYN cookie based accept.
n.isConnectNotified = true
- n.setEndpointState(StateEstablished)
+ n.transitionToStateEstablishedLocked(&handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ })
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
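
Note: the accept-path rework above moves to a fixed ordering: reserve the connection's 5-tuple, register the endpoint with the demuxer, then run the handshake, unwinding (Close, removePendingEndpoint) if any step fails. A minimal Go sketch of that ordering with stand-in steps; the function names below are hypothetical, the real calls are ReserveTuple/RegisterTransportEndpoint on the gvisor stack:

package main

import (
    "errors"
    "fmt"
)

// acceptConnection runs the three steps in order and stops at the first
// failure; in the real code each failure path also releases what the
// earlier steps acquired.
func acceptConnection(reserve, register, handshake func() error) error {
    if err := reserve(); err != nil {
        return fmt.Errorf("reserve tuple: %w", err)
    }
    if err := register(); err != nil {
        return fmt.Errorf("register endpoint: %w", err)
    }
    if err := handshake(); err != nil {
        return fmt.Errorf("handshake: %w", err)
    }
    return nil
}

func main() {
    ok := func() error { return nil }
    fail := func() error { return errors.New("port in use") }
    fmt.Println(acceptConnection(ok, fail, ok)) // register endpoint: port in use
    fmt.Println(acceptConnection(ok, ok, ok))   // <nil>
}
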
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 7da93dcc4..1798510bc 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
}
// Wait for notification.
@@ -509,9 +512,7 @@ func (h *handshake) execute() *tcpip.Error {
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
- rt := time.AfterFunc(timeOut, func() {
- resendWaker.Assert()
- })
+ rt := time.AfterFunc(timeOut, resendWaker.Assert)
defer rt.Stop()
// Set up the wakers.
@@ -618,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
@@ -995,24 +999,22 @@ func (e *endpoint) completeWorkerLocked() {
// transitionToStateEstablisedLocked transitions a given endpoint
// to an established state using the handshake parameters provided.
-// It also initializes sender/receiver if required.
+// It also initializes sender/receiver.
func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
- if e.snd == nil {
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
- }
- if e.rcv == nil {
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
- e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
- // Bootstrap the auto tuning algorithm. Starting at zero will
- // result in a really large receive window after the first auto
- // tuning adjustment.
- e.rcvAutoParams.prevCopied = int(h.rcvWnd)
- e.rcvListMu.Unlock()
- }
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd)
+ e.rcvListMu.Unlock()
+
e.setEndpointState(StateEstablished)
}
@@ -1052,8 +1054,8 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
panic("current endpoint not removed from demuxer, enqueing segments to itself")
}
- if ep.(*endpoint).enqueueSegment(s) {
- ep.(*endpoint).newSegmentWaker.Assert()
+ if ep := ep.(*endpoint); ep.enqueueSegment(s) {
+ ep.newSegmentWaker.Assert()
}
}
@@ -1122,7 +1124,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
- if e.EndpointState() == StateClose || e.EndpointState() == StateError {
+ if e.EndpointState().closed() {
return nil
}
s := e.segmentQueue.dequeue()
@@ -1442,9 +1444,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
- closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
- closeWaker.Assert()
- })
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
}
@@ -1462,7 +1462,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
return err
}
}
- if e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ if !e.EndpointState().closed() {
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
@@ -1518,6 +1518,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
cleanupOnError := func(err *tcpip.Error) {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
e.workerCleanup = true
if err != nil {
e.resetConnectionLocked(err)
@@ -1527,7 +1528,12 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
loop:
- for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ for {
+ switch e.EndpointState() {
+ case StateTimeWait, StateClose, StateError:
+ break loop
+ }
+
e.mu.Unlock()
v, _ := s.Fetch(true)
e.mu.Lock()
@@ -1570,11 +1576,14 @@ loop:
reuseTW = e.doTimeWait()
}
- // Mark endpoint as closed.
- if e.EndpointState() != StateError {
- e.transitionToStateCloseLocked()
+ // Handle any StateError transition from StateTimeWait.
+ if e.EndpointState() == StateError {
+ cleanupOnError(nil)
+ return nil
}
+ e.transitionToStateCloseLocked()
+
// Lock released below.
epilogue()
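
Note: the protocol main loop above replaces a long boolean exit condition with a switch on the endpoint state and a labeled break, so new terminal states only need a new case. A tiny self-contained Go sketch of the labeled-break pattern; the state values are illustrative, not gvisor's EndpointState:

package main

import "fmt"

type state int

const (
    established state = iota
    timeWait
    closed
)

// run handles items until a terminal state is seen, leaving the whole
// loop (not just the switch) via the label.
func run(states []state) int {
    handled := 0
loop:
    for _, s := range states {
        switch s {
        case timeWait, closed:
            break loop
        }
        handled++
    }
    return handled
}

func main() {
    fmt.Println(run([]state{established, established, timeWait, established})) // 2
}
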
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 047704c80..98aecab9e 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -15,6 +15,8 @@
package tcp
import (
+ "encoding/binary"
+
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
@@ -66,89 +68,68 @@ func (q *epQueue) empty() bool {
// processor is responsible for processing packets queued to a tcp endpoint.
type processor struct {
epQ epQueue
+ sleeper sleep.Sleeper
newEndpointWaker sleep.Waker
closeWaker sleep.Waker
- id int
- wg sync.WaitGroup
-}
-
-func newProcessor(id int) *processor {
- p := &processor{
- id: id,
- }
- p.wg.Add(1)
- go p.handleSegments()
- return p
}
func (p *processor) close() {
p.closeWaker.Assert()
}
-func (p *processor) wait() {
- p.wg.Wait()
-}
-
func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep)
p.newEndpointWaker.Assert()
}
-func (p *processor) handleSegments() {
- const newEndpointWaker = 1
- const closeWaker = 2
- s := sleep.Sleeper{}
- s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
- s.AddWaker(&p.closeWaker, closeWaker)
- defer s.Done()
+const (
+ newEndpointWaker = 1
+ closeWaker = 2
+)
+
+func (p *processor) start(wg *sync.WaitGroup) {
+ defer wg.Done()
+ defer p.sleeper.Done()
+
for {
- id, ok := s.Fetch(true)
- if ok && id == closeWaker {
- p.wg.Done()
- return
+ if id, _ := p.sleeper.Fetch(true); id == closeWaker {
+ break
}
- for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
+ for {
+ ep := p.epQ.dequeue()
+ if ep == nil {
+ break
+ }
if ep.segmentQueue.empty() {
continue
}
- // If socket has transitioned out of connected state
- // then just let the worker handle the packet.
+ // If socket has transitioned out of connected state then just let the
+ // worker handle the packet.
//
- // NOTE: We read this outside of e.mu lock which means
- // that by the time we get to handleSegments the
- // endpoint may not be in ESTABLISHED. But this should
- // be fine as all normal shutdown states are handled by
- // handleSegments and if the endpoint moves to a
- // CLOSED/ERROR state then handleSegments is a noop.
- if ep.EndpointState() != StateEstablished {
- ep.newSegmentWaker.Assert()
- continue
- }
-
- if !ep.mu.TryLock() {
- ep.newSegmentWaker.Assert()
- continue
- }
- // If the endpoint is in a connected state then we do
- // direct delivery to ensure low latency and avoid
- // scheduler interactions.
- if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
- // Send any active resets if required.
- if err != nil {
+ // NOTE: We read this outside of e.mu lock which means that by the time
+ // we get to handleSegments the endpoint may not be in ESTABLISHED. But
+ // this should be fine as all normal shutdown states are handled by
+ // handleSegments and if the endpoint moves to a CLOSED/ERROR state
+ // then handleSegments is a noop.
+ if ep.EndpointState() == StateEstablished && ep.mu.TryLock() {
+ // If the endpoint is in a connected state then we do direct delivery
+ // to ensure low latency and avoid scheduler interactions.
+ switch err := ep.handleSegments(true /* fastPath */); {
+ case err != nil:
+ // Send any active resets if required.
ep.resetConnectionLocked(err)
+ fallthrough
+ case ep.EndpointState() == StateClose:
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ case !ep.segmentQueue.empty():
+ p.epQ.enqueue(ep)
}
- ep.notifyProtocolGoroutine(notifyTickleWorker)
ep.mu.Unlock()
- continue
- }
-
- if !ep.segmentQueue.empty() {
- p.epQ.enqueue(ep)
+ } else {
+ ep.newSegmentWaker.Assert()
}
-
- ep.mu.Unlock()
}
}
}
@@ -159,31 +140,36 @@ func (p *processor) handleSegments() {
// hash of the endpoint id to ensure that delivery for the same endpoint happens
// in-order.
type dispatcher struct {
- processors []*processor
+ processors []processor
seed uint32
-}
-
-func newDispatcher(nProcessors int) *dispatcher {
- processors := []*processor{}
- for i := 0; i < nProcessors; i++ {
- processors = append(processors, newProcessor(i))
- }
- return &dispatcher{
- processors: processors,
- seed: generateRandUint32(),
+ wg sync.WaitGroup
+}
+
+func (d *dispatcher) init(nProcessors int) {
+ d.close()
+ d.wait()
+ d.processors = make([]processor, nProcessors)
+ d.seed = generateRandUint32()
+ for i := range d.processors {
+ p := &d.processors[i]
+ p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ p.sleeper.AddWaker(&p.closeWaker, closeWaker)
+ d.wg.Add(1)
+ // NB: sleeper-waker registration must happen synchronously to avoid races
+ // with `close`. It's possible to pull all this logic into `start`, but
+ // that results in a heap-allocated function literal.
+ go p.start(&d.wg)
}
}
func (d *dispatcher) close() {
- for _, p := range d.processors {
- p.close()
+ for i := range d.processors {
+ d.processors[i].close()
}
}
func (d *dispatcher) wait() {
- for _, p := range d.processors {
- p.wait()
- }
+ d.wg.Wait()
}
func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
@@ -231,20 +217,18 @@ func generateRandUint32() uint32 {
if _, err := rand.Read(b); err != nil {
panic(err)
}
- return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+ return binary.LittleEndian.Uint32(b)
}
func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
- payload := []byte{
- byte(id.LocalPort),
- byte(id.LocalPort >> 8),
- byte(id.RemotePort),
- byte(id.RemotePort >> 8)}
+ var payload [4]byte
+ binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.LittleEndian.PutUint16(payload[2:], id.RemotePort)
h := jenkins.Sum32(d.seed)
- h.Write(payload)
+ h.Write(payload[:])
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
- return d.processors[h.Sum32()%uint32(len(d.processors))]
+ return &d.processors[h.Sum32()%uint32(len(d.processors))]
}
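
Note: selectProcessor above hashes the connection 4-tuple (ports serialized little-endian, then both addresses) so every segment of a flow is dispatched to the same processor. A self-contained Go sketch of the same flow-affinity idea; it uses FNV from the standard library only to stay self-contained, whereas the real code uses a seeded Jenkins hash:

package main

import (
    "encoding/binary"
    "fmt"
    "hash/fnv"
)

// selectProcessor maps a 4-tuple to a processor index so that the same
// flow always lands on the same processor.
func selectProcessor(localPort, remotePort uint16, localAddr, remoteAddr string, nProcessors int) int {
    var ports [4]byte
    binary.LittleEndian.PutUint16(ports[0:], localPort)
    binary.LittleEndian.PutUint16(ports[2:], remotePort)

    h := fnv.New32a()
    h.Write(ports[:])
    h.Write([]byte(localAddr))
    h.Write([]byte(remoteAddr))
    return int(h.Sum32() % uint32(nProcessors))
}

func main() {
    // The same tuple always maps to the same processor index.
    fmt.Println(selectProcessor(80, 34567, "10.0.0.1", "10.0.0.2", 8))
    fmt.Println(selectProcessor(80, 34567, "10.0.0.1", "10.0.0.2", 8))
}
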
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 6e4d607da..0f7487963 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -396,7 +396,8 @@ type endpoint struct {
mu sync.Mutex `state:"nosave"`
ownedByUser uint32
- // state must be read/set using the EndpointState()/setEndpointState() methods.
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
state EndpointState `state:".(EndpointState)"`
// origEndpointState is only used during a restore phase to save the
@@ -405,8 +406,8 @@ type endpoint struct {
origEndpointState EndpointState `state:"nosave"`
isPortReserved bool `state:"manual"`
- isRegistered bool
- boundNICID tcpip.NICID `state:"manual"`
+ isRegistered bool `state:"manual"`
+ boundNICID tcpip.NICID
route stack.Route `state:"manual"`
ttl uint8
v6only bool
@@ -415,10 +416,14 @@ type endpoint struct {
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool
+ // portFlags stores the current values of port related flags.
+ portFlags ports.Flags
+
// Values used to reserve a port or register a transport endpoint
// (which ever happens first).
boundBindToDevice tcpip.NICID
boundPortFlags ports.Flags
+ boundDest tcpip.FullAddress
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -426,7 +431,7 @@ type endpoint struct {
// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
- effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
@@ -462,13 +467,6 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
- // reusePort is set to true if SO_REUSEPORT is enabled.
- reusePort bool
-
- // registeredReusePort is set if the current endpoint registration was
- // done with SO_REUSEPORT enabled.
- registeredReusePort bool
-
// bindToDevice is set to the NIC on which to bind or disabled if 0.
bindToDevice tcpip.NICID
@@ -488,7 +486,6 @@ type endpoint struct {
// The options below aren't implemented, but we remember the user
// settings because applications expect to be able to set/query these
// options.
- reuseAddr bool
// slowAck holds the negated state of quick ack. It is stubbed out and
// does nothing.
@@ -838,7 +835,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
rcvBufSize: DefaultReceiveBufferSize,
sndBufSize: DefaultSendBufferSize,
sndMTU: int(math.MaxInt32),
- reuseAddr: true,
keepalive: keepalive{
// Linux defaults.
idle: 2 * time.Hour,
@@ -1025,15 +1021,15 @@ func (e *endpoint) closeNoShutdownLocked() {
// in Listen() when trying to register.
if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
- e.registeredReusePort = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
e.isPortReserved = false
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
}
// Mark endpoint as closed.
@@ -1091,17 +1087,17 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
- e.registeredReusePort = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
e.isPortReserved = false
}
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
e.route.Release()
e.stack.CompleteTransportEndpointCleanup(e)
@@ -1213,9 +1209,27 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+ err := e.lastError
+ e.lastError = nil
+ return err
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
+ defer e.UnlockUser()
+
+ // When in SYN-SENT state, let the caller block on the receive.
+ // An application can initiate a non-blocking connect and then block
+ // on a receive. It can expect to read any data after the handshake
+ // is complete. RFC793, section 3.9, p58.
+ if e.EndpointState() == StateSynSent {
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data. Also note that a RST being received
// would cause the state to become StateError so we should allow the
@@ -1225,7 +1239,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
- e.UnlockUser()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
@@ -1235,7 +1248,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
- e.UnlockUser()
if err == tcpip.ErrClosedForReceive {
e.stats.ReadErrors.ReadClosed.Increment()
@@ -1522,12 +1534,12 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
case tcpip.ReuseAddressOption:
e.LockUser()
- e.reuseAddr = v
+ e.portFlags.TupleOnly = v
e.UnlockUser()
case tcpip.ReusePortOption:
e.LockUser()
- e.reusePort = v
+ e.portFlags.LoadBalanced = v
e.UnlockUser()
case tcpip.V6OnlyOption:
@@ -1585,6 +1597,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.UnlockUser()
e.notifyProtocolGoroutine(notifyMSSChanged)
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if attempting to set this option to
+ // anything other than path MTU discovery disabled.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1781,6 +1800,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.deferAccept = time.Duration(v)
e.UnlockUser()
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
default:
return nil
}
@@ -1831,14 +1853,14 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.ReuseAddressOption:
e.LockUser()
- v := e.reuseAddr
+ v := e.portFlags.TupleOnly
e.UnlockUser()
return v, nil
case tcpip.ReusePortOption:
e.LockUser()
- v := e.reusePort
+ v := e.portFlags.LoadBalanced
e.UnlockUser()
return v, nil
@@ -1855,6 +1877,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
+ case tcpip.MulticastLoopOption:
+ return true, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -1889,6 +1914,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
v := header.TCPDefaultMSS
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // Always return the path MTU discovery disabled setting since
+ // it's the only one supported.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1922,6 +1952,9 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.UnlockUser()
return v, nil
+ case tcpip.MulticastTTLOption:
+ return 1, nil
+
default:
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -1931,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- e.lastErrorMu.Lock()
- err := e.lastError
- e.lastError = nil
- e.lastErrorMu.Unlock()
- return err
+ return e.takeLastError()
case *tcpip.BindToDeviceOption:
e.LockUser()
@@ -2085,8 +2114,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
defer r.Release()
- origID := e.ID
-
netProtos := []tcpip.NetworkProtocolNumber{netProto}
e.ID.LocalAddress = r.LocalAddress
e.ID.RemoteAddress = r.RemoteAddress
@@ -2094,11 +2121,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
if err != nil {
return err
}
- e.registeredReusePort = e.reusePort
} else {
// The endpoint doesn't have a local port yet, so try to get
// one. Make sure that it isn't one that will result in the same
@@ -2122,40 +2148,33 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- // reusePort is false below because connect cannot reuse a port even if
- // reusePort was set.
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
return false, nil
}
id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, ports.Flags{LoadBalanced: e.reusePort}, e.bindToDevice) {
- case nil:
- // Port picking successful. Save the details of
- // the selected port.
- e.ID = id
- e.boundBindToDevice = e.bindToDevice
- e.registeredReusePort = e.reusePort
- return true, nil
- case tcpip.ErrPortInUse:
- return false, nil
- default:
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err == tcpip.ErrPortInUse {
+ return false, nil
+ }
return false, err
}
+
+ // Port picking successful. Save the details of
+ // the selected port.
+ e.ID = id
+ e.isPortReserved = true
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
+ e.boundDest = addr
+ return true, nil
}); err != nil {
return err
}
}
- // Remove the port reservation. This can happen when Bind is called
- // before Connect: in such a case we don't want to hold on to
- // reservations anymore.
- if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
- e.isPortReserved = false
- }
-
e.isRegistered = true
e.setEndpointState(StateConnecting)
e.route = r.Clone()
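The connect() rework above replaces the IsPortAvailable probe with an explicit ReservePort / RegisterTransportEndpoint pair and releases the reservation if registration fails, recording boundPortFlags, boundBindToDevice and boundDest so the eventual release matches the reservation exactly. A sketch of that invariant against a hypothetical portManager interface (not the real *stack.Stack API):

// Hypothetical interface; the real calls are methods on *stack.Stack as shown above.
type portManager interface {
	ReservePort(addr tcpip.Address, port uint16) (uint16, *tcpip.Error)
	RegisterEndpoint(port uint16) *tcpip.Error
	ReleasePort(addr tcpip.Address, port uint16)
}

func reserveAndRegister(pm portManager, addr tcpip.Address, port uint16) (uint16, *tcpip.Error) {
	p, err := pm.ReservePort(addr, port)
	if err != nil {
		return 0, err
	}
	if err := pm.RegisterEndpoint(p); err != nil {
		// Registration failed: return the reservation so the tuple is not
		// leaked, mirroring the ReleasePort call in the hunk above.
		pm.ReleasePort(addr, p)
		return 0, err
	}
	return p, nil
}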
@@ -2334,13 +2353,12 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
return err
}
e.isRegistered = true
e.setEndpointState(StateListen)
- e.registeredReusePort = e.reusePort
// The channel may be non-nil when we're restoring the endpoint, and it
// may be pre-populated with some previously accepted (but not Accepted)
@@ -2427,16 +2445,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- flags := ports.Flags{
- LoadBalanced: e.reusePort,
- }
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
if err != nil {
return err
}
e.boundBindToDevice = e.bindToDevice
- e.boundPortFlags = flags
+ e.boundPortFlags = e.portFlags
e.isPortReserved = true
e.effectiveNetProtos = netProtos
e.ID.LocalPort = port
@@ -2444,7 +2459,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Any failures beyond this point must remove the port registration.
defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{})
e.isPortReserved = false
e.effectiveNetProtos = nil
e.ID.LocalPort = 0
@@ -2467,6 +2482,10 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
e.ID.LocalAddress = addr.Addr
}
+ if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ return err
+ }
+
// Mark endpoint as bound.
e.setEndpointState(StateBound)
@@ -2531,6 +2550,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.sndBufMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
+
+ case stack.ControlNoRoute:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNoRoute
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
+
+ case stack.ControlNetworkUnreachable:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNetworkUnreachable
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
}
}
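The new HandleControlPacket cases turn ICMP-derived control signals into pending socket errors, which user code later collects through GetSockOpt(tcpip.ErrorOption{}), now routed through takeLastError. A sketch of the mapping in isolation, assuming the stack control-type constants shown above:

func controlToError(typ stack.ControlType) *tcpip.Error {
	switch typ {
	case stack.ControlNoRoute:
		return tcpip.ErrNoRoute
	case stack.ControlNetworkUnreachable:
		return tcpip.ErrNetworkUnreachable
	default:
		// Other control types keep their existing handling (for example
		// the MTU-changed notification earlier in this switch).
		return nil
	}
}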
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index cbb779666..abf1ac5c9 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -93,10 +93,6 @@ func (e *endpoint) beforeSave() {
if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
panic("endpoint still has waiters upon save")
}
-
- if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) {
- panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
- }
}
// saveAcceptedChan is invoked by stateify.
@@ -191,21 +187,28 @@ func (e *endpoint) Resume(s *stack.Stack) {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
}
- if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
}
}
}
bind := func() {
- if len(e.BindAddr) == 0 {
- e.BindAddr = e.ID.LocalAddress
+ addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort})
+ if err != nil {
+ panic("unable to parse BindAddr: " + err.String())
}
- addr := e.BindAddr
- port := e.ID.LocalPort
- if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil {
- panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err))
+ if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok {
+ panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
}
+ e.isPortReserved = true
+
+ // Mark endpoint as bound.
+ e.setEndpointState(StateBound)
}
switch {
@@ -277,17 +280,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
tcpip.AsyncLoading.Done()
}()
case epState == StateClose:
- if e.isPortReserved {
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- connectingLoading.Wait()
- bind()
- e.setEndpointState(StateClose)
- tcpip.AsyncLoading.Done()
- }()
- }
+ e.isPortReserved = false
e.state = StateClose
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
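Two behaviors change on restore: the receive buffer is now validated against the TCP ReceiveBufferSizeOption instead of the send-buffer limits, and bound or listening endpoints re-reserve the exact tuple they held before save (ReserveTuple with the saved flags, device and destination) rather than calling Bind again. A sketch of the buffer validation only, assuming a configured *stack.Stack:

func validateBufferSizes(s *stack.Stack, sndBufSize, rcvBufSize int) error {
	var ss tcp.SendBufferSizeOption
	if err := s.TransportProtocolOption(tcp.ProtocolNumber, &ss); err == nil {
		if sndBufSize < ss.Min || sndBufSize > ss.Max {
			return fmt.Errorf("sndBufSize %d outside [%d, %d]", sndBufSize, ss.Min, ss.Max)
		}
	}
	var rs tcp.ReceiveBufferSizeOption
	if err := s.TransportProtocolOption(tcp.ProtocolNumber, &rs); err == nil {
		// The receive buffer is checked against the receive limits, not the
		// send limits, which is the mismatch the hunk above fixes.
		if rcvBufSize < rs.Min || rcvBufSize > rs.Max {
			return fmt.Errorf("rcvBufSize %d outside [%d, %d]", rcvBufSize, rs.Min, rs.Max)
		}
	}
	return nil
}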
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 73b8a6782..b34e47bbd 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -71,34 +71,36 @@ const (
DefaultSynRetries = 6
)
-// SACKEnabled option can be used to enable SACK support in the TCP
-// protocol. See: https://tools.ietf.org/html/rfc2018.
+const (
+ ccReno = "reno"
+ ccCubic = "cubic"
+)
+
+// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to
+// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
-// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol.
+// DelayEnabled is used by stack.(*Stack).TransportProtocolOption to
+// enable/disable Nagle's algorithm in TCP.
type DelayEnabled bool
-// SendBufferSizeOption allows the default, min and max send buffer sizes for
-// TCP endpoints to be queried or configured.
+// SendBufferSizeOption is used by stack.(*Stack).TransportProtocolOption
+// to get/set the default, min and max TCP send buffer sizes.
type SendBufferSizeOption struct {
Min int
Default int
Max int
}
-// ReceiveBufferSizeOption allows the default, min and max receive buffer size
-// for TCP endpoints to be queried or configured.
+// ReceiveBufferSizeOption is used by
+// stack.(*Stack).TransportProtocolOption to get/set the default, min and max
+// TCP receive buffer sizes.
type ReceiveBufferSizeOption struct {
Min int
Default int
Max int
}
-const (
- ccReno = "reno"
- ccCubic = "cubic"
-)
-
// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
// value is protected by a mutex so that we can increment only when it's
// guaranteed not to go above a threshold.
@@ -172,7 +174,7 @@ type protocol struct {
maxRetries uint32
synRcvdCount synRcvdCounter
synRetries uint8
- dispatcher *dispatcher
+ dispatcher dispatcher
}
// Number returns the tcp protocol number.
@@ -513,18 +515,27 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ p := protocol{
+ sendBufferSize: SendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultSendBufferSize,
+ Max: MaxBufferSize,
+ },
+ recvBufferSize: ReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultReceiveBufferSize,
+ Max: MaxBufferSize,
+ },
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
- dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
synRetries: DefaultSynRetries,
minRTO: MinRTO,
maxRTO: MaxRTO,
maxRetries: MaxRetries,
}
+ p.dispatcher.init(runtime.GOMAXPROCS(0))
+ return &p
}
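The reworded comments make the contract explicit: SACKEnabled, DelayEnabled and the buffer-size structs are values passed to stack.(*Stack).SetTransportProtocolOption / TransportProtocolOption, not per-endpoint socket options. A sketch of typical use with arbitrary example sizes; only the option types and the SetTransportProtocolOption call are taken from this change and the tests below:

func tuneTCP(s *stack.Stack) *tcpip.Error {
	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
		return err
	}
	// Example sizes only; the stack defaults come from MinBufferSize,
	// DefaultReceiveBufferSize and MaxBufferSize above.
	rs := tcp.ReceiveBufferSizeOption{Min: 4096, Default: 1 << 20, Max: 8 << 20}
	return s.SetTransportProtocolOption(tcp.ProtocolNumber, rs)
}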
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index acacb42e4..5862c32f2 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -833,25 +833,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
panic("Netstack queues FIN segments without data.")
}
- segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- // If the entire segment cannot be accomodated in the receiver
- // advertized window, skip splitting and sending of the segment.
- // ref: net/ipv4/tcp_output.c::tcp_snd_wnd_test()
- //
- // Linux checks this for all segment transmits not triggered
- // by a probe timer. On this condition, it defers the segment
- // split and transmit to a short probe timer.
- // ref: include/net/tcp.h::tcp_check_probe_timer()
- // ref: net/ipv4/tcp_output.c::tcp_write_wakeup()
- //
- // Instead of defining a new transmit timer, we attempt to split the
- // segment right here if there are no pending segments.
- // If there are pending segments, segment transmits are deferred
- // to the retransmit timer handler.
- if s.sndUna != s.sndNxt && !segEnd.LessThan(end) {
- return false
- }
-
if !seg.sequenceNumber.LessThan(end) {
return false
}
@@ -861,14 +842,48 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
return false
}
- // The segment size limit is computed as a function of sender congestion
- // window and MSS. When sender congestion window is > 1, this limit can
- // be larger than MSS. Ensure that the currently available send space
- // is not greater than minimum of this limit and MSS.
+ // If the whole segment or at least 1 MSS sized segment cannot
+ // be accommodated in the receiver advertised window, skip
+ // splitting and sending of the segment. ref:
+ // net/ipv4/tcp_output.c::tcp_snd_wnd_test()
+ //
+ // Linux checks this for all segment transmits not triggered by
+ // a probe timer. On this condition, it defers the segment split
+ // and transmit to a short probe timer.
+ //
+ // ref: include/net/tcp.h::tcp_check_probe_timer()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup()
+ //
+ // Instead of defining a new transmit timer, we attempt to split
+ // the segment right here if there are no pending segments. If
+ // there are pending segments, segment transmits are deferred to
+ // the retransmit timer handler.
+ if s.sndUna != s.sndNxt {
+ switch {
+ case available >= seg.data.Size():
+ // OK to send, the whole segment fits in the
+ // receiver's advertised window.
+ case available >= s.maxPayloadSize:
+ // OK to send, at least 1 MSS sized segment fits
+ // in the receiver's advertised window.
+ default:
+ return false
+ }
+ }
+
+ // The segment size limit is computed as a function of sender
+ // congestion window and MSS. When sender congestion window is >
+ // 1, this limit can be larger than MSS. Ensure that the
+ // currently available send space is not greater than minimum of
+ // this limit and MSS.
if available > limit {
available = limit
}
- if available > s.maxPayloadSize {
+
+ // If GSO is not in use then cap available to
+ // maxPayloadSize. When GSO is in use the gVisor GSO logic or
+ // the host GSO logic will cap the segment to the correct size.
+ if s.ep.gso == nil && available > s.maxPayloadSize {
available = s.maxPayloadSize
}
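The rewritten window test moves the check after the FIN and sequence guards and relaxes it: with data already in flight, a segment may be sent if either the whole segment or at least one MSS fits in the receiver's advertised window; otherwise splitting is deferred to the retransmit timer, following Linux's tcp_snd_wnd_test(). The MSS cap is also skipped when GSO is in use, since the GSO path sizes segments itself. A pure-logic sketch of the send decision (not the real sender code):

// okToSendNow reports whether maybeSendSegment-style logic would try to
// transmit now. inFlight corresponds to sndUna != sndNxt above.
func okToSendNow(inFlight bool, available, segSize, mss int) bool {
	if !inFlight {
		// No pending segments: fall through to the split-and-send path.
		return true
	}
	return available >= segSize || available >= mss
}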
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 5fe23113b..b9993ce1a 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -50,7 +50,7 @@ func TestFastRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -90,14 +90,14 @@ func TestFastRecovery(t *testing.T) {
// Wait before checking metrics.
metricPollFn := func() error {
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want)
}
return nil
}
@@ -128,10 +128,10 @@ func TestFastRecovery(t *testing.T) {
// Wait before checking metrics.
metricPollFn = func() error {
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
}
return nil
}
@@ -215,7 +215,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
expected := tcp.InitialCwnd
@@ -257,7 +257,7 @@ func TestCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -362,7 +362,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -471,11 +471,11 @@ func TestRetransmit(t *testing.T) {
// MTU size though.
half := data[:len(data)/2]
if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
half = data[len(data)/2:]
if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -508,23 +508,23 @@ func TestRetransmit(t *testing.T) {
metricPollFn := func() error {
if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want)
}
return nil
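The remaining hunks in this and the following test files are mechanical: %v is replaced with the more precise %d for integer counters and %s for values with a String() method (errors, states, durations), which lets go vet catch verb/argument mismatches. A tiny illustration of the convention, with a hypothetical helper name:

func reportMismatch(t *testing.T, name string, got, want uint64) {
	// %d for integer counters; %s is reserved for Stringer values such as
	// *tcpip.Error or tcp.EndpointState.
	t.Errorf("got %s = %d, want = %d", name, got, want)
}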
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index ace79b7b2..99521f0c1 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -47,7 +47,7 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
t.Helper()
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err)
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err)
}
}
@@ -400,7 +400,7 @@ func TestSACKRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -454,7 +454,7 @@ func TestSACKRecovery(t *testing.T) {
}
for _, s := range stats {
if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %v, want = %v", s.name, got, want)
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
}
}
return nil
@@ -529,19 +529,19 @@ func TestSACKRecovery(t *testing.T) {
// In SACK recovery only the first segment is fast retransmitted when
// entering recovery.
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want)
}
return nil
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6ef32a1b3..e67ec42b1 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -57,7 +57,7 @@ func TestGiveUpConnect(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Register for notification, then start connection attempt.
@@ -66,7 +66,7 @@ func TestGiveUpConnect(t *testing.T) {
defer wq.EventUnregister(&waitEntry)
if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
// Close the connection, wait for completion.
@@ -75,21 +75,21 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
+ t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted)
}
// Call Connect again to retrieve the handshake failure status
// and stats updates.
if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted)
}
if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
}
@@ -102,7 +102,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
}
@@ -115,10 +115,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want)
+ t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want)
}
}
@@ -129,20 +129,20 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
want := stats.TCP.FailedConnectionAttempts.Value() + 1
if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
- t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute)
}
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want)
+ t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want)
}
}
@@ -156,10 +156,10 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
- t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
- t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want)
+ t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want)
}
}
@@ -170,16 +170,16 @@ func TestTCPResetsSentIncrement(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
want := stats.TCP.SegmentsSent.Value() + 1
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send a SYN request.
@@ -213,7 +213,7 @@ func TestTCPResetsSentIncrement(t *testing.T) {
metricPollFn := func() error {
if got := stats.TCP.ResetsSent.Value(); got != want {
- return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want)
}
return nil
}
@@ -292,7 +292,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// are released instantly on Close.
tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
- t.Fatalf("e.stack.SetTransportProtocolOption(%d, %v) = %v", tcp.ProtocolNumber, tcpTW, err)
+ t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err)
}
c.EP.Close()
@@ -355,7 +355,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
})
if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
}
}
@@ -379,7 +379,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) {
})
if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
}
c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
}
@@ -403,7 +403,7 @@ func TestNonBlockingClose(t *testing.T) {
t0 := time.Now()
ep.Close()
if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
+ t.Fatalf("Took too long to close: %s", diff)
}
}
@@ -415,7 +415,7 @@ func TestConnectResetAfterClose(t *testing.T) {
// after 3 second in FIN_WAIT2 state.
tcpLingerTimeout := 3 * time.Second
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err)
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -497,11 +497,11 @@ func TestCurrentConnectedIncrement(t *testing.T) {
c.EP = nil
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got)
}
gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
if gotConnected != 1 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 1", gotConnected)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected)
}
ep.Close()
@@ -524,10 +524,10 @@ func TestCurrentConnectedIncrement(t *testing.T) {
})
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = %v", got, gotConnected)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected)
}
// Ack and send FIN as well.
@@ -556,10 +556,10 @@ func TestCurrentConnectedIncrement(t *testing.T) {
time.Sleep(1200 * time.Millisecond)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -575,7 +575,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
c.EP = nil
if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
}
// Send a FIN for ESTABLISHED --> CLOSED-WAIT
@@ -603,7 +603,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
time.Sleep(10 * time.Millisecond)
if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
}
// Close the application endpoint for CLOSE_WAIT --> LAST_ACK
@@ -620,7 +620,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
)
if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Pause the endpoint's protocolMainLoop.
@@ -657,15 +657,15 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
// Expect the endpoint to be closed.
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
// Check if the endpoint was moved to CLOSED and netstack a reset in
@@ -691,7 +691,7 @@ func TestSimpleReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -714,7 +714,7 @@ func TestSimpleReceive(t *testing.T) {
// Receive data.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -781,7 +781,7 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
// Start connection attempt to IPv4 address.
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet with our user supplied MSS.
@@ -842,7 +842,7 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
// Start connection attempt to IPv6 address.
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet with our user supplied MSS.
@@ -1239,7 +1239,7 @@ func TestConnectBindToDevice(t *testing.T) {
defer c.WQ.EventUnregister(&waitEntry)
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -1251,7 +1251,7 @@ func TestConnectBindToDevice(t *testing.T) {
),
)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
}
tcpHdr := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
@@ -1270,7 +1270,7 @@ func TestConnectBindToDevice(t *testing.T) {
c.GetPacket()
if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
}
})
}
@@ -1291,7 +1291,7 @@ func TestRstOnSynSent(t *testing.T) {
addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted)
+ t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
}
// Receive SYN packet.
@@ -1332,7 +1332,7 @@ func TestRstOnSynSent(t *testing.T) {
}
if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
}
// Due to the RST the endpoint should be in an error state.
@@ -1352,7 +1352,7 @@ func TestOutOfOrderReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send second half of data first, with seqnum 3 ahead of expected.
@@ -1379,7 +1379,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send the first 3 bytes now.
@@ -1406,7 +1406,7 @@ func TestOutOfOrderReceive(t *testing.T) {
}
continue
}
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
read = append(read, v...)
@@ -1436,7 +1436,7 @@ func TestOutOfOrderFlood(t *testing.T) {
c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send 100 packets before the actual one that is expected.
@@ -1513,7 +1513,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -1556,7 +1556,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// This final ACK should be ignored because an ACK on a reset doesn't mean
@@ -1582,7 +1582,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -1624,7 +1624,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
))
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Cause a RST to be generated by closing the read end now since we have
@@ -1643,7 +1643,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// The ACK to the FIN should now be rejected since the connection has been
@@ -1665,19 +1665,19 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
- t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want)
+ t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
}
}
@@ -1693,7 +1693,7 @@ func TestFullWindowReceive(t *testing.T) {
_, _, err := c.EP.Read(nil)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
// Fill up the window.
@@ -1728,7 +1728,7 @@ func TestFullWindowReceive(t *testing.T) {
// Receive data and check it.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -1737,7 +1737,7 @@ func TestFullWindowReceive(t *testing.T) {
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
- t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want)
+ t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want)
}
// Check that we get an ACK for the newly non-zero window.
@@ -1760,7 +1760,7 @@ func TestNoWindowShrinking(t *testing.T) {
c.CreateConnected(789, 30000, 10)
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -1768,7 +1768,7 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send 3 bytes, check that the peer acknowledges them.
@@ -1832,7 +1832,7 @@ func TestNoWindowShrinking(t *testing.T) {
for len(read) < len(data) {
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
read = append(read, v...)
@@ -1866,7 +1866,7 @@ func TestSimpleSend(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received.
@@ -1908,7 +1908,7 @@ func TestZeroWindowSend(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check if we got a zero-window probe.
@@ -1976,7 +1976,7 @@ func TestScaledWindowConnect(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xbfff,
@@ -2008,7 +2008,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -2036,21 +2036,21 @@ func TestScaledWindowAccept(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -2068,7 +2068,7 @@ func TestScaledWindowAccept(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2081,7 +2081,7 @@ func TestScaledWindowAccept(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xbfff,
@@ -2109,21 +2109,21 @@ func TestNonScaledWindowAccept(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
@@ -2142,7 +2142,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2155,7 +2155,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -2244,7 +2244,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
for sz < defaultMTU {
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
sz += len(v)
}
@@ -2311,7 +2311,7 @@ func TestSegmentMerging(t *testing.T) {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2381,7 +2381,7 @@ func TestDelay(t *testing.T) {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2428,7 +2428,7 @@ func TestUndelay(t *testing.T) {
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2512,7 +2512,7 @@ func TestMSSNotDelayed(t *testing.T) {
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2563,7 +2563,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received in chunks.
@@ -2631,7 +2631,7 @@ func TestSetTTL(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
@@ -2639,7 +2639,7 @@ func TestSetTTL(t *testing.T) {
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %s", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -2671,7 +2671,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
@@ -2683,11 +2683,11 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -2705,7 +2705,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2794,7 +2794,7 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) {
select {
case err := <-ch:
if err != nil {
- t.Fatalf("Error creating endpoint: %v", err)
+ t.Fatalf("Error creating endpoint: %s", err)
}
case <-time.After(2 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -2813,7 +2813,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Set the buffer size to a deterministic size so that we can check the
@@ -2830,7 +2830,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
// Receive SYN packet.
@@ -2884,7 +2884,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
select {
case <-ch:
if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -2899,22 +2899,22 @@ func TestCloseListener(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Close the listener and measure how long it takes.
t0 := time.Now()
ep.Close()
if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
+ t.Fatalf("Took too long to close: %s", diff)
}
}
@@ -2950,22 +2950,25 @@ loop:
case tcpip.ErrConnectionReset:
break loop
default:
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
// Expect the state to be StateError and subsequent Reads to fail with HardError.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -2990,7 +2993,7 @@ func TestSendOnResetConnection(t *testing.T) {
// Try to write.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
@@ -3013,7 +3016,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Expect first transmit and MaxRetries retransmits.
@@ -3048,7 +3051,10 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
)
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -3066,7 +3072,7 @@ func TestMaxRTO(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.TCP(
@@ -3089,6 +3095,63 @@ func TestMaxRTO(t *testing.T) {
}
}
+// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
+// unique on retransmits.
+func TestRetransmitIPv4IDUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ size int
+ }{
+ {"1Byte", 1},
+ {"512Bytes", 512},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ // Disabling PMTU discovery causes all packets sent from this socket to
+ // have DF=0. This needs to be done because the IPv4 ID uniqueness
+ // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
+ // Section 4, and datagrams with DF=0 are non-atomic.
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
+ t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}}
+ // Expect two retransmitted packets, and that all packets received have
+ // unique IPv4 ID values.
+ for i := 0; i <= 2; i++ {
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ id := header.IPv4(pkt).ID()
+ if _, exists := idSet[id]; exists {
+ t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
+ }
+ idSet[id] = struct{}{}
+ }
+ })
+ }
+}
+
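// The test above relies on RFC 6864: IPv4 ID uniqueness is only required
// for non-atomic datagrams, so PMTU discovery is disabled first to force
// DF=0. A hypothetical helper (not part of the change) distilling the
// duplicate-ID check it performs on retransmits:
func firstDuplicateIPv4ID(packets [][]byte) (uint16, bool) {
	seen := make(map[uint16]struct{})
	for _, p := range packets {
		id := header.IPv4(p).ID()
		if _, ok := seen[id]; ok {
			return id, true
		}
		seen[id] = struct{}{}
	}
	return 0, false
}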
func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -3097,7 +3160,7 @@ func TestFinImmediately(t *testing.T) {
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3140,7 +3203,7 @@ func TestFinRetransmit(t *testing.T) {
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3195,7 +3258,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3221,7 +3284,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Shutdown, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3268,7 +3331,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
view := buffer.NewView(10)
for i := tcp.InitialCwnd; i > 0; i-- {
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
}
@@ -3290,7 +3353,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
// because the congestion window doesn't allow it. Wait until a
// retransmit is received.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3354,7 +3417,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3380,7 +3443,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write new data, but don't acknowledge it.
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3396,7 +3459,7 @@ func TestFinWithPendingData(t *testing.T) {
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3441,7 +3504,7 @@ func TestFinWithPartialAck(t *testing.T) {
// FIN from the test side.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3478,7 +3541,7 @@ func TestFinWithPartialAck(t *testing.T) {
// Write new data, but don't acknowledge it.
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3494,7 +3557,7 @@ func TestFinWithPartialAck(t *testing.T) {
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3540,20 +3603,20 @@ func TestUpdateListenBacklog(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Update the backlog with another Listen() on the same endpoint.
if err := ep.Listen(20); err != nil {
- t.Fatalf("Listen failed to update backlog: %v", err)
+ t.Fatalf("Listen failed to update backlog: %s", err)
}
ep.Close()
@@ -3585,7 +3648,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
// Send some data. Check that it's capped by the window size.
view := buffer.NewView(65535)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that only data that fits in the scaled window is sent.
@@ -3631,18 +3694,18 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
})
if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
- t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want)
+ t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want)
}
// Ensure there were no errors during handshake. If these stats have
// incremented, then the connection should not have been established.
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0)
+ t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0)
}
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0)
+ t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0)
}
}
@@ -3666,10 +3729,10 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c.SendSegment(vv)
if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
}
}
@@ -3770,7 +3833,7 @@ func TestReadAfterClosedState(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Shutdown immediately for write, check that we get a FIN.
@@ -3789,7 +3852,7 @@ func TestReadAfterClosedState(t *testing.T) {
)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Send some data and acknowledge the FIN.
@@ -3818,7 +3881,7 @@ func TestReadAfterClosedState(t *testing.T) {
time.Sleep(tcpTimeWaitTimeout * 2)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Wait for receive to be notified.
@@ -3853,11 +3916,11 @@ func TestReadAfterClosedState(t *testing.T) {
// Now that we drained the queue, check that functions fail with the
// right error code.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -3871,66 +3934,84 @@ func TestReusePort(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
	// Second case, an endpoint that was bound and is connecting.
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
// Third case, an endpoint that was bound and is listening.
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
}
@@ -3939,11 +4020,11 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
if int(s) != v {
- t.Fatalf("got receive buffer size = %v, want = %v", s, v)
+ t.Fatalf("got receive buffer size = %d, want = %d", s, v)
}
}
@@ -3952,11 +4033,11 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
if int(s) != v {
- t.Fatalf("got send buffer size = %v, want = %v", s, v)
+ t.Fatalf("got send buffer size = %d, want = %d", s, v)
}
}
@@ -3969,7 +4050,7 @@ func TestDefaultBufferSizes(t *testing.T) {
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer func() {
if ep != nil {
@@ -3981,28 +4062,34 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultSendBufferSize * 2,
+ Max: tcp.DefaultSendBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
ep.Close()
ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize * 3,
+ Max: tcp.DefaultReceiveBufferSize * 30}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
ep.Close()
ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
@@ -4018,17 +4105,17 @@ func TestMinMaxBufferSizes(t *testing.T) {
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
// Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Set values below the min.
@@ -4065,12 +4152,12 @@ func TestBindToDeviceOption(t *testing.T) {
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
if err := s.CreateNIC(321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
+ t.Errorf("CreateNIC failed: %s", err)
}
// nicIDPtr is used instead of taking the address of NICID literals, which is
@@ -4095,12 +4182,12 @@ func TestBindToDeviceOption(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
+ t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
bindToDevice := tcpip.BindToDeviceOption(88888)
if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt got %v, want %v", err, nil)
+ t.Errorf("GetSockOpt got %s, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
t.Errorf("bindToDevice got %d, want %d", got, want)
@@ -4166,12 +4253,12 @@ func TestSelfConnect(t *testing.T) {
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Register for notification, then start connection attempt.
@@ -4180,12 +4267,12 @@ func TestSelfConnect(t *testing.T) {
defer wq.EventUnregister(&waitEntry)
if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
<-notifyCh
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("Connect failed: %v", err)
+ t.Fatalf("Connect failed: %s", err)
}
// Write something.
@@ -4193,7 +4280,7 @@ func TestSelfConnect(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Read back what was written.
@@ -4202,12 +4289,12 @@ func TestSelfConnect(t *testing.T) {
rd, _, err := ep.Read(nil)
if err != nil {
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
<-notifyCh
rd, _, err = ep.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
}
@@ -4291,7 +4378,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
eps = append(eps, ep)
switch network {
@@ -4342,7 +4429,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
- t.Fatalf("Bind(%d) failed: %v", i, err)
+ t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
want := tcpip.ErrConnectStarted
@@ -4350,7 +4437,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
want = tcpip.ErrNoPortAvailable
}
if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
- t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want)
+ t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
}
})
}
@@ -4384,7 +4471,7 @@ func TestPathMTUDiscovery(t *testing.T) {
}
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
@@ -4487,7 +4574,7 @@ func TestStackSetCongestionControl(t *testing.T) {
var oldCC tcpip.CongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
@@ -4574,12 +4661,12 @@ func TestEndpointSetCongestionControl(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
var oldCC tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&oldCC); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
+ t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err)
}
if connected {
@@ -4587,12 +4674,12 @@ func TestEndpointSetCongestionControl(t *testing.T) {
}
if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
- t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
+ t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err)
}
var cc tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&cc); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
+ t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err)
}
got, want := cc, oldCC
@@ -4615,7 +4702,7 @@ func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
opt := tcpip.CongestionControlOption("cubic")
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err)
}
}
@@ -4657,14 +4744,14 @@ func TestKeepalive(t *testing.T) {
// Check that the connection is still alive.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
view := buffer.NewView(3)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -4744,15 +4831,18 @@ func TestKeepalive(t *testing.T) {
)
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -4854,19 +4944,19 @@ func TestListenBacklogFull(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
// Start listening.
listenBacklog := 2
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
for i := 0; i < listenBacklog; i++ {
@@ -4899,7 +4989,7 @@ func TestListenBacklogFull(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4928,7 +5018,7 @@ func TestListenBacklogFull(t *testing.T) {
case <-ch:
newEP, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4942,7 +5032,7 @@ func TestListenBacklogFull(t *testing.T) {
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
}
@@ -5162,19 +5252,19 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
// Start listening.
listenBacklog := 1
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send two SYN's the first one should get a SYN-ACK, the
@@ -5240,7 +5330,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
case <-ch:
newEP, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5254,7 +5344,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
pkt := c.GetPacket()
tcp = header.TCP(header.IPv4(pkt).Payload())
if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
}
@@ -5316,7 +5406,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5450,7 +5540,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
pkt := c.GetPacket()
tcpHdr = header.TCP(header.IPv4(pkt).Payload())
if string(tcpHdr.Payload()) != data {
- t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
}
}
@@ -5460,20 +5550,20 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
stats := c.Stack().Stats()
@@ -5494,7 +5584,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5503,7 +5593,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
}
if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
}
@@ -5514,14 +5604,14 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
srcPort := uint16(context.TestPort)
@@ -5546,10 +5636,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
time.Sleep(50 * time.Millisecond)
if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want)
+ t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -5564,7 +5654,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5579,28 +5669,28 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
- t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected)
+ t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
}
if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
- t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1)
+ t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -5617,7 +5707,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
case <-ch:
aep, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5625,25 +5715,25 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
}
}
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
- t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected)
+ t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected)
}
// Listening endpoint remains in listen state.
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
ep.Close()
// Give worker goroutines time to receive the close notification.
time.Sleep(1 * time.Second)
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Accepted endpoint remains open when the listen endpoint is closed.
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
}
@@ -5663,13 +5753,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 500.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Enable auto-tuning.
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Change the expected window scale to match the value needed for the
// maximum buffer size defined above.
@@ -5784,13 +5874,13 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Enable auto-tuning.
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Change the expected window scale to match the value needed for the
// maximum buffer size used by stack.
@@ -5935,7 +6025,7 @@ func TestDelayEnabled(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
- t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err)
+ t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err)
}
checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
}
@@ -5946,7 +6036,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del
var gotDelayEnabled tcp.DelayEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
- t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err)
+ t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
}
if gotDelayEnabled != wantDelayEnabled {
t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
@@ -5954,7 +6044,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
if err != nil {
- t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err)
+ t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
}
gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
@@ -6515,10 +6605,10 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.TCPFlags(header.TCPFlagRst)))
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
}
@@ -6715,7 +6805,7 @@ func TestTCPUserTimeout(t *testing.T) {
// Send some data and wait before ACKing it.
view := buffer.NewView(3)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -6765,11 +6855,14 @@ func TestTCPUserTimeout(t *testing.T) {
)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -6796,7 +6889,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
// Check that the connection is still alive.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
	// Now receive 1 keepalive, but don't ACK it.
@@ -6837,10 +6930,13 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -6896,11 +6992,11 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
// ack should be sent in response to that. The window was not
// zero, but it grew to larger than MSS.
if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
// After reading two packets, we surely crossed MSS. See the ack:
@@ -6997,13 +7093,13 @@ func TestTCPDeferAccept(t *testing.T) {
const tcpDeferAccept = 1 * time.Second
if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
}
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Send data. This should result in an acceptable endpoint.
@@ -7026,7 +7122,7 @@ func TestTCPDeferAccept(t *testing.T) {
time.Sleep(50 * time.Millisecond)
aep, _, err := c.EP.Accept()
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
}
aep.Close()
@@ -7054,13 +7150,13 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
const tcpDeferAccept = 1 * time.Second
if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
}
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Sleep for a little of the tcpDeferAccept timeout.
@@ -7094,7 +7190,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
time.Sleep(50 * time.Millisecond)
aep, _, err := c.EP.Accept()
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
}
aep.Close()
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 9721f6caf..06fde2a79 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -144,12 +144,12 @@ func New(t *testing.T, mtu uint32) *Context {
})
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Increase minimum RTO in tests to avoid test flakes due to early
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index c70525f27..7981d469b 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -85,6 +85,7 @@ func (t *timer) init(w *sleep.Waker) {
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
t.timer.Stop()
+ *t = timer{}
}
// checkExpiration checks if the given timer has actually expired, it should be
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
new file mode 100644
index 000000000..dbd6dff54
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer_test.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+)
+
+func TestCleanup(t *testing.T) {
+ const (
+ timerDurationSeconds = 2
+ isAssertedTimeoutSeconds = timerDurationSeconds + 1
+ )
+
+ tmr := timer{}
+ w := sleep.Waker{}
+ tmr.init(&w)
+ tmr.enable(timerDurationSeconds * time.Second)
+ tmr.cleanup()
+
+ if want := (timer{}); tmr != want {
+ t.Errorf("got tmr = %+v, want = %+v", tmr, want)
+ }
+
+ // The waker should not be asserted.
+ for i := 0; i < isAssertedTimeoutSeconds; i++ {
+ time.Sleep(time.Second)
+ if w.IsAsserted() {
+ t.Fatalf("waker asserted unexpectedly")
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 12bc1b5b5..558b06df0 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
return st
}
+// State returns the current state of the TCB.
+func (t *TCB) State() Result {
+ return t.state
+}
+
// IsAlive returns true as long as the connection is in the established (Alive)
// or connecting state.
func (t *TCB) IsAlive() bool {
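Editorial sketch, not part of the patch: the new State accessor lets connection-tracking callers inspect a TCB without feeding it another segment the way the UpdateState* methods do. The helper below is hypothetical, and it assumes ResultAlive/ResultReset are among the package's existing Result values:

	// expireIfReset is an illustrative caller of the new State() accessor.
	func expireIfReset(tcb *tcpconntrack.TCB, drop func()) {
		if tcb.IsAlive() {
			return
		}
		// State() reports *why* the connection is no longer alive,
		// which IsAlive() alone cannot distinguish.
		if tcb.State() == tcpconntrack.ResultReset {
			drop()
		}
	}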
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index df5efbf6a..6e692da07 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,6 +15,8 @@
package udp
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -94,6 +96,7 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
+ sndBufSizeMax int
state EndpointState
route stack.Route `state:"manual"`
dstPort uint16
@@ -106,6 +109,7 @@ type endpoint struct {
portFlags ports.Flags
bindToDevice tcpip.NICID
broadcast bool
+ noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -159,7 +163,7 @@ type multicastMembership struct {
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
- return &endpoint{
+ e := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
@@ -181,10 +185,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
multicastTTL: 1,
multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
state: StateInitial,
uniqueID: s.UniqueID(),
}
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
+ }
+
+ return e
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
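Editorial sketch, not part of the patch: the Default values consulted above come from stack-wide options. Assuming stack.Stack exposes a SetOption counterpart to the s.Option call used in newEndpoint (and that it returns *tcpip.Error), seeding those defaults might look like this, reusing the constants added to udp/protocol.go further down:

	// setUDPBufferDefaults seeds stack-wide buffer sizes so that endpoints
	// created afterwards inherit them as their initial snd/rcv maximums.
	func setUDPBufferDefaults(s *stack.Stack) *tcpip.Error {
		if err := s.SetOption(stack.SendBufferSizeOption{
			Min:     udp.MinBufferSize,
			Default: udp.DefaultSendBufferSize,
			Max:     udp.MaxBufferSize,
		}); err != nil {
			return err
		}
		return s.SetOption(stack.ReceiveBufferSizeOption{
			Min:     udp.MinBufferSize,
			Default: udp.DefaultReceiveBufferSize,
			Max:     udp.MaxBufferSize,
		})
	}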
@@ -215,7 +232,7 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
}
@@ -513,7 +530,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -537,6 +554,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.multicastLoop = v
e.mu.Unlock()
+ case tcpip.NoChecksumOption:
+ e.mu.Lock()
+ e.noChecksum = v
+ e.mu.Unlock()
+
case tcpip.ReceiveTOSOption:
e.mu.Lock()
e.receiveTOS = v
@@ -590,6 +612,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
e.multicastTTL = uint8(v)
@@ -611,8 +640,43 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.mu.Unlock()
case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err))
+ }
+
+ if v < rs.Min {
+ v = rs.Min
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ e.mu.Lock()
+ e.rcvBufSizeMax = v
+ e.mu.Unlock()
+ return nil
case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err))
+ }
+
+ if v < ss.Min {
+ v = ss.Min
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
}
return nil
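Editorial sketch, not part of the patch: the clamping above is silent, so an application only notices it by reading the option back. Assuming ep is a UDP tcpip.Endpoint:

	// showClamping demonstrates that an oversized request is clamped to the
	// stack-wide maximum rather than rejected.
	func showClamping(ep tcpip.Endpoint) (int, *tcpip.Error) {
		if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1<<30); err != nil {
			return 0, err
		}
		// The value read back is at most the stack's SendBufferSizeOption.Max,
		// not the 1GiB that was requested.
		return ep.GetSockOptInt(tcpip.SendBufferSizeOption)
	}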
@@ -752,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Lock()
e.bindToDevice = id
e.mu.Unlock()
+
+ case tcpip.SocketDetachFilterOption:
+ return nil
}
return nil
}
@@ -774,6 +841,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.NoChecksumOption:
+ e.mu.RLock()
+ v := e.noChecksum
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.ReceiveTOSOption:
e.mu.RLock()
v := e.receiveTOS
@@ -843,6 +916,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
v := int(e.multicastTTL)
@@ -861,7 +938,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -908,7 +985,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -922,8 +999,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
Length: length,
})
- // Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
+ (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
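Editorial sketch, not part of the patch: the transmit-side rule reduces to a small predicate. The helper name below is made up; it simply restates the condition sendUDP now uses:

	// wantTXChecksum reports whether sendUDP should fill in the UDP checksum.
	func wantTXChecksum(caps stack.LinkEndpointCapabilities, netProto tcpip.NetworkProtocolNumber, noChecksum bool) bool {
		if caps&stack.CapabilityTXChecksumOffload != 0 {
			return false // offload hardware fills it in
		}
		// IPv6 always requires a checksum; IPv4 may send zero when the
		// socket opted out via tcpip.NoChecksumOption.
		return !noChecksum || netProto == header.IPv6ProtocolNumber
	}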
@@ -996,7 +1077,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
e.state = StateInitial
@@ -1147,7 +1228,7 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
if err != nil {
return id, e.bindToDevice, err
}
@@ -1157,7 +1238,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
return id, e.bindToDevice, err
@@ -1299,10 +1380,37 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
- e.rcvMu.Lock()
+ // Never receive from a multicast address.
+ if header.IsV4MulticastAddress(id.RemoteAddress) ||
+ header.IsV6MulticastAddress(id.RemoteAddress) {
+ e.stack.Stats().UDP.InvalidSourceAddress.Increment()
+ e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ // Verify checksum unless RX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value means
+ // the transmitter omitted the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
+ (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ if hdr.CalculateChecksum(xsum) != 0xffff {
+ // Checksum Error.
+ e.stack.Stats().UDP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
+ return
+ }
+ }
+
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
+ e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed {
e.rcvMu.Unlock()
@@ -1343,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
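Editorial sketch, not part of the patch: the receive path mirrors the transmit rule, after first rejecting datagrams whose source address is multicast. The helper name below is made up; it restates the verification condition now used in HandlePacket:

	// wantRXVerification reports whether HandlePacket should verify the
	// received UDP checksum.
	func wantRXVerification(caps stack.LinkEndpointCapabilities, netProto tcpip.NetworkProtocolNumber, pktChecksum uint16) bool {
		if caps&stack.CapabilityRXChecksumOffload != 0 {
			return false // the link layer already validated it
		}
		// A zero checksum only means "no checksum" on IPv4 (RFC 768);
		// IPv6 always verifies (RFC 2460 Section 8.1).
		return pktChecksum != 0 || netProto == header.IPv6ProtocolNumber
	}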
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 4218e7d03..0e7464e3a 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -32,9 +32,24 @@ import (
const (
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
+
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4KiB.
+
+ // DefaultSendBufferSize is the default size of the send buffer for
+ // an endpoint.
+ DefaultSendBufferSize = 32 << 10 // 32KiB
+
+ // DefaultReceiveBufferSize is the default size of the receive buffer
+ // for an endpoint.
+ DefaultReceiveBufferSize = 32 << 10 // 32KiB
+
+ // MaxBufferSize is the largest size a receive/send buffer can grow to.
+ MaxBufferSize = 4 << 20 // 4MiB
)
-type protocol struct{}
+type protocol struct {
+}
// Number returns the udp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
@@ -183,12 +198,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 313a3f117..90781cf49 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -83,16 +83,18 @@ type header4Tuple struct {
type testFlow int
const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ reverseMulticast4 // V4 multicast src. Must fail.
+ reverseMulticast6 // V6 multicast src. Must fail.
)
func (flow testFlow) String() string {
@@ -117,6 +119,10 @@ func (flow testFlow) String() string {
return "broadcast"
case broadcastIn6:
return "broadcastIn6"
+ case reverseMulticast4:
+ return "reverseMulticast4"
+ case reverseMulticast6:
+ return "reverseMulticast6"
default:
return "unknown"
}
@@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
h.dstAddr.Addr = multicastV6Addr
}
}
+ if flow.isReverseMulticast() {
+ h.srcAddr.Addr = flow.getMcastAddr()
+ }
return h
}
@@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
// endpoint for this flow.
func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast:
+ case unicastV4, multicastV4, broadcast, reverseMulticast4:
return ipv4.ProtocolNumber
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool {
switch flow {
case unicastV6Only, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool {
switch flow {
case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool {
switch flow {
case broadcast, broadcastIn6:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool {
switch flow {
case unicastV4in6, multicastV4in6, broadcastIn6:
return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
}
}
+func (flow testFlow) isReverseMulticast() bool {
+ switch flow {
+ case reverseMulticast4, reverseMulticast6:
+ return true
+ default:
+ return false
+ }
+}
+
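The reverseMulticast flows added above build packets whose source (rather than destination) address is multicast, which a receiver must drop. A sketch of the screen they exercise, using the stack's multicast-address predicates; isMulticastSource itself is a hypothetical helper, not code from this change:

// isMulticastSource reports whether a datagram carries a multicast source
// address, which is never valid and should be counted as an invalid source.
func isMulticastSource(src tcpip.Address) bool {
	return header.IsV4MulticastAddress(src) || header.IsV6MulticastAddress(src)
}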
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -292,15 +310,15 @@ func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Optio
wep = sniffer.New(ep)
}
if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatalf("CreateNIC failed: %s", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -391,17 +409,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
} else {
- c.injectV6Packet(payload, &h, true /* valid */)
+ buf := c.buildV6Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
}
-// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV6Packet creates a V6 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -420,16 +442,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
- l := uint16(header.UDPMinimumSize + len(payload))
- if !valid {
- // Change the UDP payload length to corrupt the header
- // as requested by the caller.
- l++
- }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: l,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
@@ -439,17 +455,12 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ return buf
}
-// injectV4Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV4Packet creates a V4 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -483,11 +494,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
-
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ return buf
}
func newPayload() []byte {
@@ -509,7 +516,7 @@ func TestBindToDeviceOption(t *testing.T) {
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
@@ -643,7 +650,7 @@ func TestBindEphemeralPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}
@@ -654,19 +661,19 @@ func TestBindReservedPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
addr, err := c.ep.GetLocalAddress()
if err != nil {
- t.Fatalf("GetLocalAddress failed: %v", err)
+ t.Fatalf("GetLocalAddress failed: %s", err)
}
// We can't bind the address reserved by the connected endpoint above.
{
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
@@ -677,7 +684,7 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// We can't bind ipv4-any on the port reserved by the connected endpoint
@@ -687,7 +694,7 @@ func TestBindReservedPort(t *testing.T) {
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
@@ -697,11 +704,11 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
}
@@ -714,7 +721,7 @@ func TestV4ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -729,7 +736,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
// Bind to v4 mapped wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -744,7 +751,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -759,7 +766,7 @@ func TestV6ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -796,7 +803,10 @@ func TestV4ReadSelfSource(t *testing.T) {
h := unicastV4.header4Tuple(incoming)
h.srcAddr = h.dstAddr
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
@@ -817,7 +827,7 @@ func TestV4ReadOnV4(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -880,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
}
}
+// TestReadFromMulticast checks that an endpoint will NOT receive a packet
+// that was sent with a multicast SOURCE address.
+func TestReadFromMulticast(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ testFailingRead(c, flow, false /* expectReadError */)
+ })
+ }
+}
+
+// TestReadFromMulticastStats checks that a discarded packet that was sent
+// with a multicast SOURCE address increments the correct counters and that
+// a regular packet does not.
+func TestReadFromMulticastStats(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ c.injectPacket(flow, payload)
+
+ var want uint64 = 0
+ if flow.isReverseMulticast() {
+ want = 1
+ }
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want {
+ t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want {
+ t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
// and receive broadcast and unicast data.
func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
@@ -955,7 +1019,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ c.t.Fatalf("Write failed: %s", err)
}
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
@@ -1005,7 +1069,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
p := testDualWrite(c)
@@ -1022,7 +1086,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV6)
@@ -1043,7 +1107,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV4in6)
@@ -1070,7 +1134,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Write to v6 address.
@@ -1085,7 +1149,7 @@ func TestV6WriteOnConnected(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV6)
@@ -1099,7 +1163,7 @@ func TestV4WriteOnConnected(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV4)
@@ -1234,7 +1298,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testRead(c, unicastV4)
@@ -1259,6 +1323,30 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
+func TestNoChecksum(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Disable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ // This option is effective on IPv4 only.
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
+
+ // Enable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
+ })
+ }
+}
+
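TestNoChecksum above toggles tcpip.NoChecksumOption and expects a zero UDP checksum only for IPv4 traffic. A sketch of that send-side rule; omitTxChecksum is a hypothetical name, not a function in this patch:

// omitTxChecksum reports whether an outgoing datagram may leave its UDP
// checksum as zero: only when the option is set and the packet is sent over
// IPv4, where the checksum is optional.
func omitTxChecksum(noChecksum bool, netProto tcpip.NetworkProtocolNumber) bool {
	return noChecksum && netProto == header.IPv4ProtocolNumber
}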
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1506,12 +1594,12 @@ func TestMulticastInterfaceOption(t *testing.T) {
Port: stackPort,
}
if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
}
if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ c.t.Fatalf("SetSockOpt failed: %s", err)
}
// Verify multicast interface addr and NIC were set correctly.
@@ -1519,7 +1607,7 @@ func TestMulticastInterfaceOption(t *testing.T) {
ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
var ifoptGot tcpip.MulticastInterfaceOption
if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
+ c.t.Fatalf("GetSockOpt failed: %s", err)
}
if ifoptGot != ifoptWant {
c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
@@ -1691,7 +1779,7 @@ func TestV6UnknownDestination(t *testing.T) {
}
// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats get incremented.
+// global and endpoint stats are incremented.
func TestIncrementMalformedPacketsReceived(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1699,20 +1787,27 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
payload := newPayload()
- c.t.Helper()
h := unicastV6.header4Tuple(incoming)
- c.injectV6Packet(payload, &h, false /* !valid */)
+ buf := c.buildV6Packet(payload, &h)
- var want uint64 = 1
+ // Invalidate the UDP header length field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetLength(u.Length() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
}
if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
}
}
@@ -1728,7 +1823,6 @@ func TestShortHeader(t *testing.T) {
c.t.Fatalf("Bind failed: %s", err)
}
- c.t.Helper()
h := unicastV6.header4Tuple(incoming)
// Allocate a buffer for an IPv6 and too-short UDP header.
@@ -1768,6 +1862,199 @@ func TestShortHeader(t *testing.T) {
}
}
+// TestIncrementChecksumErrorsV4 verifies that if a checksum error is
+// detected, global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+
+ // Invalidate the UDP header checksum field, taking care to avoid
+ // overflow to zero, which would disable checksum validation.
+ for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
+ u.SetChecksum(u.Checksum() + 1)
+ if u.Checksum() != 0 {
+ break
+ }
+ }
+
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV6 verifies that if a checksum error is
+// detected, global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+
+ // Invalidate the UDP header checksum field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV4 verifies that if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV6 verifies that if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV4 verifies that if the checksum value is zero, global and
+// endpoint stats are *not* incremented (the UDP checksum is optional on IPv4).
+func TestChecksumZeroV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 0
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV6 verifies that if the checksum value is zero, global and
+// endpoint stats are incremented (the UDP checksum is *not* optional on IPv6).
+func TestChecksumZeroV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
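TestChecksumZeroV4 and TestChecksumZeroV6 above pin down the receive-side rule for a zero checksum field. A sketch of that decision; zeroChecksumAcceptable is a hypothetical helper and only illustrates the rule the tests assert:

// zeroChecksumAcceptable reports whether a received datagram whose UDP
// checksum field is zero may be accepted without validation: yes over IPv4
// (the checksum is optional, RFC 768), no over IPv6 (mandatory, RFC 8200).
func zeroChecksumAcceptable(netProto tcpip.NetworkProtocolNumber) bool {
	return netProto == header.IPv4ProtocolNumber
}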
// TestShutdownRead verifies endpoint read shutdown and error
// stats increment on packet receive.
func TestShutdownRead(t *testing.T) {
@@ -1778,15 +2065,15 @@ func TestShutdownRead(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingRead(c, unicastV6, true /* expectReadError */)
@@ -1809,11 +2096,11 @@ func TestShutdownWrite(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)