summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/BUILD4
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go221
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go10
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go38
-rw-r--r--pkg/tcpip/transport/raw/BUILD4
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go369
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go6
-rw-r--r--pkg/tcpip/transport/raw/protocol.go9
-rw-r--r--pkg/tcpip/transport/tcp/BUILD8
-rw-r--r--pkg/tcpip/transport/tcp/accept.go48
-rw-r--r--pkg/tcpip/transport/tcp/connect.go74
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go56
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go586
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go28
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go26
-rw-r--r--pkg/tcpip/transport/tcp/snd.go12
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go18
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go576
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go52
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD3
-rw-r--r--pkg/tcpip/transport/udp/BUILD8
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go373
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go18
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go8
-rw-r--r--pkg/tcpip/transport/udp/protocol.go113
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go632
27 files changed, 2430 insertions, 878 deletions
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index d78a162b8..9254c3dea 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "icmp_packet_list",
out = "icmp_packet_list.go",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 451d3880e..3187b336b 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,7 +15,6 @@
package icmp
import (
- "encoding/binary"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -53,11 +52,11 @@ const (
//
// +stateify savable
type endpoint struct {
+ stack.TransportEndpointInfo
+
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
- transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
// The following fields are used to manage the receive queue, and are
@@ -74,27 +73,23 @@ type endpoint struct {
sndBufSize int
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
- id stack.TransportEndpointID
state endpointState
- // bindNICID and bindAddr are set via calls to Bind(). They are used to
- // reject attempts to send data or connect via a different NIC or
- // address
- bindNICID tcpip.NICID
- bindAddr tcpip.Address
- // regNICID is the default NIC to be used when callers don't specify a
- // NIC.
- regNICID tcpip.NICID
- route stack.Route `state:"manual"`
-}
-
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ route stack.Route `state:"manual"`
+ ttl uint8
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+}
+
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return &endpoint{
- stack: stack,
- netProto: netProto,
- transProto: transProto,
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
+ state: stateInitial,
}, nil
}
@@ -105,7 +100,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -144,6 +139,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
if e.rcvList.Empty() {
err := tcpip.ErrWouldBlock
if e.rcvClosed {
+ e.stats.ReadErrors.ReadClosed.Increment()
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
@@ -205,7 +201,30 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ n, ch, err := e.write(p, opts)
+ switch err {
+ case nil:
+ e.stats.PacketsSent.Increment()
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ e.stats.WriteErrors.InvalidArgs.Increment()
+ case tcpip.ErrClosedForSend:
+ e.stats.WriteErrors.WriteClosed.Increment()
+ case tcpip.ErrInvalidEndpointState:
+ e.stats.WriteErrors.InvalidEndpointState.Increment()
+ case tcpip.ErrNoLinkAddress:
+ e.stats.SendErrors.NoLinkAddr.Increment()
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ // Errors indicating any problem with IP routing of the packet.
+ e.stats.SendErrors.NoRoute.Increment()
+ default:
+ // For all other errors when writing to the network layer.
+ e.stats.SendErrors.SendToNetworkFailed.Increment()
+ }
+ return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -256,12 +275,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicid := to.NIC
- if e.bindNICID != 0 {
- if nicid != 0 && nicid != e.bindNICID {
+ if e.BindNICID != 0 {
+ if nicid != 0 && nicid != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.bindNICID
+ nicid = e.BindNICID
}
toCopy := *to
@@ -272,7 +291,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
// Find the enpoint.
- r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
if err != nil {
return 0, nil, err
}
@@ -290,17 +309,17 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
- switch e.netProto {
+ switch e.NetProto {
case header.IPv4ProtocolNumber:
- err = send4(route, e.id.LocalPort, v)
+ err = send4(route, e.ID.LocalPort, v, e.ttl)
case header.IPv6ProtocolNumber:
- err = send6(route, e.id.LocalPort, v)
+ err = send6(route, e.ID.LocalPort, v, e.ttl)
}
if err != nil {
@@ -315,8 +334,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOpt sets a socket option. Currently not supported.
+// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(o)
+ e.mu.Unlock()
+ }
+
+ return nil
+}
+
+// SetSockOptInt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
return nil
}
@@ -332,6 +363,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -342,40 +385,33 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
return nil
- case *tcpip.ReceiveBufferSizeOption:
+ case *tcpip.TTLOption:
e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+ *o = tcpip.TTLOption(e.ttl)
e.rcvMu.Unlock()
return nil
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident to the user-specified port. Sequence number should
- // already be set by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident)
-
hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
copy(icmpv4, data)
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ icmpv4.SetIdent(ident)
data = data[header.ICMPv4MinimumSize:]
// Linux performs these basic checks.
@@ -386,22 +422,24 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
- return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
}
-func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident. Sequence number is provided by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident)
-
- hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength()))
+ hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength()))
- icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
copy(icmpv6, data)
- data = data[header.ICMPv6EchoMinimumSize:]
+ // Set the ident. Sequence number is provided by the user.
+ icmpv6.SetIdent(ident)
+ data = data[header.ICMPv6MinimumSize:]
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return tcpip.ErrInvalidEndpointState
@@ -410,18 +448,21 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
icmpv6.SetChecksum(0)
icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
- return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL())
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.netProto
+ netProto := e.NetProto
if header.IsV4MappedAddress(addr.Addr) {
return 0, tcpip.ErrNoRoute
}
// Fail if we're bound to an address length different from the one we're
// checking.
- if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
+ if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -442,16 +483,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
localPort := uint16(0)
switch e.state {
case stateBound, stateConnected:
- localPort = e.id.LocalPort
- if e.bindNICID == 0 {
+ localPort = e.ID.LocalPort
+ if e.BindNICID == 0 {
break
}
- if nicid != 0 && nicid != e.bindNICID {
+ if nicid != 0 && nicid != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.bindNICID
+ nicid = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
@@ -462,7 +503,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
@@ -484,9 +525,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return err
}
- e.id = id
+ e.ID = id
e.route = r.Clone()
- e.regNICID = nicid
+ e.RegisterNICID = nicid
e.state = stateConnected
@@ -541,14 +582,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
@@ -595,8 +636,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return err
}
- e.id = id
- e.regNICID = addr.NIC
+ e.ID = id
+ e.RegisterNICID = addr.NIC
// Mark endpoint as bound.
e.state = stateBound
@@ -619,8 +660,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
return err
}
- e.bindNICID = addr.NIC
- e.bindAddr = addr.Addr
+ e.BindNICID = addr.NIC
+ e.BindAddr = addr.Addr
return nil
}
@@ -631,9 +672,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
return tcpip.FullAddress{
- NIC: e.regNICID,
- Addr: e.id.LocalAddress,
- Port: e.id.LocalPort,
+ NIC: e.RegisterNICID,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
}, nil
}
@@ -647,9 +688,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}
return tcpip.FullAddress{
- NIC: e.regNICID,
- Addr: e.id.RemoteAddress,
- Port: e.id.RemotePort,
+ NIC: e.RegisterNICID,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
}, nil
}
@@ -675,17 +716,19 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
// Only accept echo replies.
- switch e.netProto {
+ switch e.NetProto {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(vv.First())
if h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
h := header.ICMPv6(vv.First())
if h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
}
@@ -693,9 +736,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
- if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ if !e.rcvReady || e.rcvClosed {
+ e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
return
}
@@ -717,7 +768,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
pkt.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
-
+ e.stats.PacketsReceived.Increment()
// Notify any waiters that there's data to be read now.
if wasEmpty {
e.waiterQueue.Notify(waiter.EventIn)
@@ -733,3 +784,17 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
func (e *endpoint) State() uint32 {
return 0
}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.TransportEndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index c587b96b6..9d263c0ec 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -76,19 +76,19 @@ func (e *endpoint) Resume(s *stack.Stack) {
var err *tcpip.Error
if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
if err != nil {
panic(err)
}
- e.id.LocalAddress = e.route.LocalAddress
- } else if len(e.id.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
+ e.ID.LocalAddress = e.route.LocalAddress
+ } else if len(e.ID.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
- e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
+ e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 7fdba5d56..bfb16f7c3 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -14,16 +14,14 @@
// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
// protocols for use in ping. To use it in the networking stack, this package
-// must be added to the project, and
-// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or
-// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when
-// calling stack.New(). Then endpoints can be created by passing
+// must be added to the project, and activated on the stack by passing
+// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
+// protocols when calling stack.New(). Then endpoints can be created by passing
// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
// when calling Stack.NewEndpoint().
package icmp
import (
- "encoding/binary"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -35,15 +33,9 @@ import (
)
const (
- // ProtocolName4 is the string representation of the icmp protocol name.
- ProtocolName4 = "icmp4"
-
// ProtocolNumber4 is the ICMP protocol number.
ProtocolNumber4 = header.ICMPv4ProtocolNumber
- // ProtocolName6 is the string representation of the icmp protocol name.
- ProtocolName6 = "icmp6"
-
// ProtocolNumber6 is the IPv6-ICMP protocol number.
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
@@ -92,7 +84,7 @@ func (p *protocol) MinimumPacketSize() int {
case ProtocolNumber4:
return header.ICMPv4MinimumSize
case ProtocolNumber6:
- return header.ICMPv6EchoMinimumSize
+ return header.ICMPv6MinimumSize
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
@@ -101,16 +93,18 @@ func (p *protocol) MinimumPacketSize() int {
func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
switch p.number {
case ProtocolNumber4:
- return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil
+ hdr := header.ICMPv4(v)
+ return 0, hdr.Ident(), nil
case ProtocolNumber6:
- return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil
+ hdr := header.ICMPv6(v)
+ return 0, hdr.Ident(), nil
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
return true
}
@@ -124,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
- })
+// NewProtocol4 returns an ICMPv4 transport protocol.
+func NewProtocol4() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+}
- stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
- })
+// NewProtocol6 returns an ICMPv6 transport protocol.
+func NewProtocol6() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 7241f6c19..fba598d51 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "packet_list",
out = "packet_list.go",
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 13e17e2a6..b4c660859 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -62,11 +62,10 @@ type packet struct {
//
// +stateify savable
type endpoint struct {
+ stack.TransportEndpointInfo
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
- transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
associated bool
@@ -84,18 +83,10 @@ type endpoint struct {
closed bool
connected bool
bound bool
- // registeredNIC is the NIC to which th endpoint is explicitly
- // registered. Is set when Connect or Bind are used to specify a NIC.
- registeredNIC tcpip.NICID
- // boundNIC and boundAddr are set on calls to Bind(). When callers
- // attempt actions that would invalidate the binding data (e.g. sending
- // data via a NIC other than boundNIC), the endpoint will return an
- // error.
- boundNIC tcpip.NICID
- boundAddr tcpip.Address
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
- route stack.Route `state:"manual"`
+ route stack.Route `state:"manual"`
+ stats tcpip.TransportEndpointStats `state:"nosave"`
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -104,15 +95,17 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
if netProto != header.IPv4ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
- ep := &endpoint{
- stack: stack,
- netProto: netProto,
- transProto: transProto,
+ e := &endpoint{
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
@@ -123,81 +116,82 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
// headers included. Because they're write-only, We don't need to
// register with the stack.
if !associated {
- ep.rcvBufSizeMax = 0
- ep.waiterQueue = nil
- return ep, nil
+ e.rcvBufSizeMax = 0
+ e.waiterQueue = nil
+ return e, nil
}
- if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
return nil, err
}
- return ep, nil
+ return e, nil
}
// Close implements tcpip.Endpoint.Close.
-func (ep *endpoint) Close() {
- ep.mu.Lock()
- defer ep.mu.Unlock()
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
- if ep.closed || !ep.associated {
+ if e.closed || !e.associated {
return
}
- ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+ e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
- ep.rcvMu.Lock()
- defer ep.rcvMu.Unlock()
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
// Clear the receive list.
- ep.rcvClosed = true
- ep.rcvBufSize = 0
- for !ep.rcvList.Empty() {
- ep.rcvList.Remove(ep.rcvList.Front())
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ e.rcvList.Remove(e.rcvList.Front())
}
- if ep.connected {
- ep.route.Release()
- ep.connected = false
+ if e.connected {
+ e.route.Release()
+ e.connected = false
}
- ep.closed = true
+ e.closed = true
- ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
-func (ep *endpoint) ModerateRecvBuf(copied int) {}
+func (e *endpoint) ModerateRecvBuf(copied int) {}
// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (iptables.IPTables, error) {
- return ep.stack.IPTables(), nil
+func (e *endpoint) IPTables() (iptables.IPTables, error) {
+ return e.stack.IPTables(), nil
}
// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if !ep.associated {
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ if !e.associated {
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
}
- ep.rcvMu.Lock()
+ e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
// endpoint is closed.
- if ep.rcvList.Empty() {
+ if e.rcvList.Empty() {
err := tcpip.ErrWouldBlock
- if ep.rcvClosed {
+ if e.rcvClosed {
+ e.stats.ReadErrors.ReadClosed.Increment()
err = tcpip.ErrClosedForReceive
}
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
return buffer.View{}, tcpip.ControlMessages{}, err
}
- packet := ep.rcvList.Front()
- ep.rcvList.Remove(packet)
- ep.rcvBufSize -= packet.data.Size()
+ packet := e.rcvList.Front()
+ e.rcvList.Remove(packet)
+ e.rcvBufSize -= packet.data.Size()
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
if addr != nil {
*addr = packet.senderAddr
@@ -207,31 +201,54 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
}
// Write implements tcpip.Endpoint.Write.
-func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ n, ch, err := e.write(p, opts)
+ switch err {
+ case nil:
+ e.stats.PacketsSent.Increment()
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ e.stats.WriteErrors.InvalidArgs.Increment()
+ case tcpip.ErrClosedForSend:
+ e.stats.WriteErrors.WriteClosed.Increment()
+ case tcpip.ErrInvalidEndpointState:
+ e.stats.WriteErrors.InvalidEndpointState.Increment()
+ case tcpip.ErrNoLinkAddress:
+ e.stats.SendErrors.NoLinkAddr.Increment()
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ // Errors indicating any problem with IP routing of the packet.
+ e.stats.SendErrors.NoRoute.Increment()
+ default:
+ // For all other errors when writing to the network layer.
+ e.stats.SendErrors.SendToNetworkFailed.Increment()
+ }
+ return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
}
- ep.mu.RLock()
+ e.mu.RLock()
- if ep.closed {
- ep.mu.RUnlock()
+ if e.closed {
+ e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
- payloadBytes, err := payload.Get(payload.Size())
+ payloadBytes, err := p.FullPayload()
if err != nil {
- ep.mu.RUnlock()
+ e.mu.RUnlock()
return 0, nil, err
}
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if !ep.associated {
+ if !e.associated {
ip := header.IPv4(payloadBytes)
- if !ip.IsValid(payload.Size()) {
- ep.mu.RUnlock()
+ if !ip.IsValid(len(payloadBytes)) {
+ e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
dstAddr := ip.DestinationAddress()
@@ -252,66 +269,66 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64
if opts.To == nil {
// If the user doesn't specify a destination, they should have
// connected to another address.
- if !ep.connected {
- ep.mu.RUnlock()
+ if !e.connected {
+ e.mu.RUnlock()
return 0, nil, tcpip.ErrDestinationRequired
}
- if ep.route.IsResolutionRequired() {
- savedRoute := &ep.route
+ if e.route.IsResolutionRequired() {
+ savedRoute := &e.route
// Promote lock to exclusive if using a shared route,
// given that it may need to change in finishWrite.
- ep.mu.RUnlock()
- ep.mu.Lock()
+ e.mu.RUnlock()
+ e.mu.Lock()
// Make sure that the route didn't change during the
// time we didn't hold the lock.
- if !ep.connected || savedRoute != &ep.route {
- ep.mu.Unlock()
+ if !e.connected || savedRoute != &e.route {
+ e.mu.Unlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
- n, ch, err := ep.finishWrite(payloadBytes, savedRoute)
- ep.mu.Unlock()
+ n, ch, err := e.finishWrite(payloadBytes, savedRoute)
+ e.mu.Unlock()
return n, ch, err
}
- n, ch, err := ep.finishWrite(payloadBytes, &ep.route)
- ep.mu.RUnlock()
+ n, ch, err := e.finishWrite(payloadBytes, &e.route)
+ e.mu.RUnlock()
return n, ch, err
}
// The caller provided a destination. Reject destination address if it
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
- if ep.bound && nic != 0 && nic != ep.boundNIC {
- ep.mu.RUnlock()
+ if e.bound && nic != 0 && nic != e.BindNICID {
+ e.mu.RUnlock()
return 0, nil, tcpip.ErrNoRoute
}
// We don't support IPv6 yet, so this has to be an IPv4 address.
if len(opts.To.Addr) != header.IPv4AddressSize {
- ep.mu.RUnlock()
+ e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
- // Find the route to the destination. If boundAddress is 0,
+ // Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
- route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false)
+ route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
- ep.mu.RUnlock()
+ e.mu.RUnlock()
return 0, nil, err
}
- n, ch, err := ep.finishWrite(payloadBytes, &route)
+ n, ch, err := e.finishWrite(payloadBytes, &route)
route.Release()
- ep.mu.RUnlock()
+ e.mu.RUnlock()
return n, ch, err
}
// finishWrite writes the payload to a route. It resolves the route if
// necessary. It's really just a helper to make defer unnecessary in Write.
-func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) {
// We may need to resolve the route (match a link layer address to the
// network address). If that requires blocking (e.g. to use ARP),
// return a channel on which the caller can wait.
@@ -324,16 +341,16 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- switch ep.netProto {
+ switch e.NetProto {
case header.IPv4ProtocolNumber:
- if !ep.associated {
+ if !e.associated {
if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil {
return 0, nil, err
}
break
}
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil {
+ if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
return 0, nil, err
}
@@ -345,7 +362,7 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
// Peek implements tcpip.Endpoint.Peek.
-func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -355,11 +372,11 @@ func (*endpoint) Disconnect() *tcpip.Error {
}
// Connect implements tcpip.Endpoint.Connect.
-func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- ep.mu.Lock()
- defer ep.mu.Unlock()
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
- if ep.closed {
+ if e.closed {
return tcpip.ErrInvalidEndpointState
}
@@ -369,15 +386,15 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
nic := addr.NIC
- if ep.bound {
- if ep.boundNIC == 0 {
+ if e.bound {
+ if e.BindNICID == 0 {
// If we're bound, but not to a specific NIC, the NIC
// in addr will be used. Nothing to do here.
} else if addr.NIC == 0 {
// If we're bound to a specific NIC, but addr doesn't
// specify a NIC, use the bound NIC.
- nic = ep.boundNIC
- } else if addr.NIC != ep.boundNIC {
+ nic = e.BindNICID
+ } else if addr.NIC != e.BindNICID {
// We're bound and addr specifies a NIC. They must be
// the same.
return tcpip.ErrInvalidEndpointState
@@ -385,53 +402,53 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Find a route to the destination.
- route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false)
+ route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false)
if err != nil {
return err
}
defer route.Release()
- if ep.associated {
+ if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
return err
}
- ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
- ep.registeredNIC = nic
+ e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.RegisterNICID = nic
}
// Save the route we've connected via.
- ep.route = route.Clone()
- ep.connected = true
+ e.route = route.Clone()
+ e.connected = true
return nil
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
-func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- ep.mu.Lock()
- defer ep.mu.Unlock()
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
- if !ep.connected {
+ if !e.connected {
return tcpip.ErrNotConnected
}
return nil
}
// Listen implements tcpip.Endpoint.Listen.
-func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept implements tcpip.Endpoint.Accept.
-func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
// Bind implements tcpip.Endpoint.Bind.
-func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
- ep.mu.Lock()
- defer ep.mu.Unlock()
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
// Callers must provide an IPv4 address or no network address (for
// binding to a NIC, but not an address).
@@ -440,94 +457,100 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 {
+ if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
- if ep.associated {
+ if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
return err
}
- ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
- ep.registeredNIC = addr.NIC
- ep.boundNIC = addr.NIC
+ e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.RegisterNICID = addr.NIC
+ e.BindNICID = addr.NIC
}
- ep.boundAddr = addr.Addr
- ep.bound = true
+ e.BindAddr = addr.Addr
+ e.bound = true
return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotSupported
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// Even a connected socket doesn't return a remote address.
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
// Readiness implements tcpip.Endpoint.Readiness.
-func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// The endpoint is always writable.
result := waiter.EventOut & mask
// Determine whether the endpoint is readable.
if (mask & waiter.EventIn) != 0 {
- ep.rcvMu.Lock()
- if !ep.rcvList.Empty() || ep.rcvClosed {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
result |= waiter.EventIn
}
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
}
return result
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
- ep.rcvMu.Lock()
- if !ep.rcvList.Empty() {
- p := ep.rcvList.Front()
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() {
+ p := e.rcvList.Front()
v = p.data.Size()
}
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- ep.mu.Lock()
- *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
- ep.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- ep.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
- ep.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -538,37 +561,45 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
- ep.rcvMu.Lock()
+func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+ e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
- if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax {
- ep.stack.Stats().DroppedPackets.Increment()
- ep.rcvMu.Unlock()
+ if e.rcvClosed {
+ e.rcvMu.Unlock()
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
return
}
- if ep.bound {
+ if e.rcvBufSize >= e.rcvBufSizeMax {
+ e.rcvMu.Unlock()
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return
+ }
+
+ if e.bound {
// If bound to a NIC, only accept data for that NIC.
- if ep.boundNIC != 0 && ep.boundNIC != route.NICID() {
- ep.rcvMu.Unlock()
+ if e.BindNICID != 0 && e.BindNICID != route.NICID() {
+ e.rcvMu.Unlock()
return
}
// If bound to an address, only accept data for that address.
- if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress {
- ep.rcvMu.Unlock()
+ if e.BindAddr != "" && e.BindAddr != route.RemoteAddress {
+ e.rcvMu.Unlock()
return
}
}
// If connected, only accept packets from the remote address we
// connected to.
- if ep.connected && ep.route.RemoteAddress != route.RemoteAddress {
- ep.rcvMu.Unlock()
+ if e.connected && e.route.RemoteAddress != route.RemoteAddress {
+ e.rcvMu.Unlock()
return
}
- wasEmpty := ep.rcvBufSize == 0
+ wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
packet := &packet{
@@ -581,20 +612,34 @@ func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv b
combinedVV := netHeader.ToVectorisedView()
combinedVV.Append(vv)
packet.data = combinedVV.Clone(packet.views[:])
- packet.timestampNS = ep.stack.NowNanoseconds()
-
- ep.rcvList.PushBack(packet)
- ep.rcvBufSize += packet.data.Size()
+ packet.timestampNS = e.stack.NowNanoseconds()
- ep.rcvMu.Unlock()
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
+ e.rcvMu.Unlock()
+ e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
if wasEmpty {
- ep.waiterQueue.Notify(waiter.EventIn)
+ e.waiterQueue.Notify(waiter.EventIn)
}
}
// State implements socket.Socket.State.
-func (ep *endpoint) State() uint32 {
+func (e *endpoint) State() uint32 {
return 0
}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.TransportEndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 168953dec..a6c7cc43a 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -73,7 +73,7 @@ func (ep *endpoint) Resume(s *stack.Stack) {
// If the endpoint is connected, re-connect.
if ep.connected {
var err *tcpip.Error
- ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
+ ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false)
if err != nil {
panic(err)
}
@@ -81,12 +81,12 @@ func (ep *endpoint) Resume(s *stack.Stack) {
// If the endpoint is bound, re-bind.
if ep.bound {
- if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
+ if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
- if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
panic(err)
}
}
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index 783c21e6b..a2512d666 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -20,13 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-type factory struct{}
+// EndpointFactory implements stack.UnassociatedEndpointFactory.
+type EndpointFactory struct{}
// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
-
-func init() {
- stack.RegisterUnassociatedFactory(factory{})
-}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 1ee1a53f8..aed70e06f 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,7 +1,8 @@
-package(licenses = ["notice"])
-
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
go_template_instance(
name = "tcp_segment_list",
@@ -47,6 +48,7 @@ go_library(
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/seqnum",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index e9c5099ea..844959fa0 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -143,6 +143,15 @@ func decSynRcvdCount() {
synRcvdCount.Unlock()
}
+// synCookiesInUse() returns true if the synRcvdCount is greater than
+// SynRcvdCountThreshold.
+func synCookiesInUse() bool {
+ synRcvdCount.Lock()
+ v := synRcvdCount.value
+ synRcvdCount.Unlock()
+ return v >= SynRcvdCountThreshold
+}
+
// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
@@ -220,7 +229,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
}
n := newEndpoint(l.stack, netProto, nil)
n.v6only = l.v6only
- n.id = s.id
+ n.ID = s.id
n.boundNICID = s.route.NICID()
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
@@ -233,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil {
n.Close()
return nil, err
}
@@ -281,7 +290,6 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
h.resetToSynRcvd(cookie, irs, opts)
if err := h.execute(); err != nil {
- ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.Close()
if l.listenEP != nil {
l.removePendingEndpoint(ep)
@@ -302,14 +310,14 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
func (l *listenContext) addPendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
- l.pendingEndpoints[n.id] = n
+ l.pendingEndpoints[n.ID] = n
l.pending.Add(1)
l.pendingMu.Unlock()
}
func (l *listenContext) removePendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
- delete(l.pendingEndpoints, n.id)
+ delete(l.pendingEndpoints, n.ID)
l.pending.Done()
l.pendingMu.Unlock()
}
@@ -354,6 +362,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
n, err := ctx.createEndpointAndPerformHandshake(s, opts)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
return
}
ctx.removePendingEndpoint(n)
@@ -405,6 +414,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
}
decSynRcvdCount()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
} else {
@@ -412,6 +422,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// is full then drop the syn.
if e.acceptQueueIsFull() {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
}
@@ -430,7 +441,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
TSEcr: opts.TSVal,
MSS: uint16(mss),
}
- sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+ e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
}
@@ -442,10 +453,32 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// complete the connection at the time of retransmit if
// the backlog has space.
e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
}
+ if !synCookiesInUse() {
+ // Send a reset as this is an ACK for which there is no
+ // half open connections and we are not using cookies
+ // yet.
+ //
+ // The only time we should reach here when a connection
+ // was opened and closed really quickly and a delayed
+ // ACK was received from the sender.
+ replyWithReset(s)
+ return
+ }
+
+ // Since SYN cookies are in use this is potentially an ACK to a
+ // SYN-ACK we sent but don't have a half open connection state
+ // as cookies are being used to protect against a potential SYN
+ // flood. In such cases validate the cookie and if valid create
+ // a fully connected endpoint and deliver to the accept queue.
+ //
+ // If not, silently drop the ACK to avoid leaking information
+ // when under a potential syn flood attack.
+ //
// Validate the cookie.
data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
if !ok || int(data) >= len(mssTable) {
@@ -475,6 +508,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
return
}
@@ -506,7 +540,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
v6only := e.v6only
e.mu.Unlock()
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+ ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 00d2ae524..5ea036bea 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -238,6 +238,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
h.state = handshakeSynRcvd
h.ep.mu.Lock()
h.ep.state = StateSynRecv
+ ttl := h.ep.ttl
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
@@ -251,8 +252,10 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
SACKPermitted: rcvSynOpts.SACKPermitted,
MSS: h.ep.amss,
}
- sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
-
+ if ttl == 0 {
+ ttl = s.route.DefaultTTL()
+ }
+ h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
}
@@ -296,7 +299,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
}
@@ -383,6 +386,11 @@ func (h *handshake) resolveRoute() *tcpip.Error {
switch index {
case wakerForResolution:
if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+ if err == tcpip.ErrNoLinkAddress {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ } else if err != nil {
+ h.ep.stats.SendErrors.NoRoute.Increment()
+ }
// Either success (err == nil) or failure.
return err
}
@@ -460,7 +468,8 @@ func (h *handshake) execute() *tcpip.Error {
synOpts.WS = -1
}
}
- sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
for h.state != handshakeCompleted {
switch index, _ := s.Fetch(true); index {
case wakerForResend:
@@ -469,7 +478,7 @@ func (h *handshake) execute() *tcpip.Error {
return tcpip.ErrTimeout
}
rt.Reset(timeOut)
- sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
case wakerForNotification:
n := h.ep.fetchNotifications()
@@ -579,16 +588,28 @@ func makeSynOptions(opts header.TCPSynOptions) []byte {
return options[:offset]
}
-func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
+func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
options := makeSynOptions(opts)
- err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil)
+ // We ignore SYN send errors and let the callers re-attempt send.
+ if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil {
+ e.stats.SendErrors.SynSendToNetworkFailed.Increment()
+ }
putOptions(options)
- return err
+ return nil
+}
+
+func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil {
+ e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
+ return err
+ }
+ e.stats.SegmentsSent.Increment()
+ return nil
}
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
optLen := len(opts)
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
@@ -624,12 +645,18 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ r.Stats().TCP.SegmentSendErrors.Increment()
+ return err
+ }
r.Stats().TCP.SegmentsSent.Increment()
if (flags & header.TCPFlagRst) != 0 {
r.Stats().TCP.ResetsSent.Increment()
}
-
- return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl)
+ return nil
}
// makeOptions makes an options slice.
@@ -678,7 +705,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso)
+ err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso)
putOptions(options)
return err
}
@@ -720,13 +747,18 @@ func (e *endpoint) handleClose() *tcpip.Error {
return nil
}
-// resetConnectionLocked sends a RST segment and puts the endpoint in an error
-// state with the given error code. This method must only be called from the
-// protocol goroutine.
+// resetConnectionLocked puts the endpoint in an error state with the given
+// error code and sends a RST if and only if the error is not ErrConnectionReset
+// indicating that the connection is being reset due to receiving a RST. This
+// method must only be called from the protocol goroutine.
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ // Only send a reset if the connection is being aborted for a reason
+ // other than receiving a reset.
e.state = StateError
- e.hardError = err
+ e.HardError = err
+ if err != tcpip.ErrConnectionReset {
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ }
}
// completeWorkerLocked is called by the worker goroutine when it's about to
@@ -806,7 +838,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
- return tcpip.ErrConnectionReset
+ return tcpip.ErrTimeout
}
// RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
@@ -893,7 +925,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.mu.Lock()
e.state = StateError
- e.hardError = err
+ e.HardError = err
// Lock released below.
epilogue()
@@ -1068,6 +1100,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.workMu.Lock()
if err := funcs[v].f(); err != nil {
e.mu.Lock()
+ // Ensure we release all endpoint registration and route
+ // references as the connection is now in an error
+ // state.
+ e.workerCleanup = true
e.resetConnectionLocked(err)
// Lock released below.
epilogue()
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index c54610a87..dfaa4a559 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -42,7 +42,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) {
}
}
-func testV4Connect(t *testing.T, c *context.Context) {
+func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
// Start connection attempt.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventOut)
@@ -55,12 +55,11 @@ func testV4Connect(t *testing.T, c *context.Context) {
// Receive SYN packet.
b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ synCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ))
+ checker.IPv4(t, b, synCheckers...)
tcp := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
@@ -76,14 +75,13 @@ func testV4Connect(t *testing.T, c *context.Context) {
})
// Receive ACK packet.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
- ),
- )
+ ackCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ))
+ checker.IPv4(t, c.GetPacket(), ackCheckers...)
// Wait for connection to be established.
select {
@@ -152,7 +150,7 @@ func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) {
testV4Connect(t, c)
}
-func testV6Connect(t *testing.T, c *context.Context) {
+func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
// Start connection attempt to IPv6 address.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventOut)
@@ -165,12 +163,11 @@ func testV6Connect(t *testing.T, c *context.Context) {
// Receive SYN packet.
b := c.GetV6Packet()
- checker.IPv6(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ synCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ))
+ checker.IPv6(t, b, synCheckers...)
tcp := header.TCP(header.IPv6(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
@@ -186,14 +183,13 @@ func testV6Connect(t *testing.T, c *context.Context) {
})
// Receive ACK packet.
- checker.IPv6(t, c.GetV6Packet(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
- ),
- )
+ ackCheckers := append(checkers, checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ))
+ checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
// Wait for connection to be established.
select {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ac927569a..a1b784b49 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"fmt"
"math"
"strings"
@@ -26,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -170,6 +172,101 @@ type rcvBufAutoTuneParams struct {
disabled bool
}
+// ReceiveErrors collect segment receive errors within transport layer.
+type ReceiveErrors struct {
+ tcpip.ReceiveErrors
+
+ // SegmentQueueDropped is the number of segments dropped due to
+ // a full segment queue.
+ SegmentQueueDropped tcpip.StatCounter
+
+ // ChecksumErrors is the number of segments dropped due to bad checksums.
+ ChecksumErrors tcpip.StatCounter
+
+ // ListenOverflowSynDrop is the number of times the listen queue overflowed
+ // and a SYN was dropped.
+ ListenOverflowSynDrop tcpip.StatCounter
+
+ // ListenOverflowAckDrop is the number of times the final ACK
+ // in the handshake was dropped due to overflow.
+ ListenOverflowAckDrop tcpip.StatCounter
+
+ // ZeroRcvWindowState is the number of times we advertised
+ // a zero receive window when rcvList is full.
+ ZeroRcvWindowState tcpip.StatCounter
+}
+
+// SendErrors collect segment send errors within the transport layer.
+type SendErrors struct {
+ tcpip.SendErrors
+
+ // SegmentSendToNetworkFailed is the number of TCP segments failed to be sent
+ // to the network endpoint.
+ SegmentSendToNetworkFailed tcpip.StatCounter
+
+ // SynSendToNetworkFailed is the number of TCP SYNs failed to be sent
+ // to the network endpoint.
+ SynSendToNetworkFailed tcpip.StatCounter
+
+ // Retransmits is the number of TCP segments retransmitted.
+ Retransmits tcpip.StatCounter
+
+ // FastRetransmit is the number of segments retransmitted in fast
+ // recovery.
+ FastRetransmit tcpip.StatCounter
+
+ // Timeouts is the number of times the RTO expired.
+ Timeouts tcpip.StatCounter
+}
+
+// Stats holds statistics about the endpoint.
+type Stats struct {
+ // SegmentsReceived is the number of TCP segments received that
+ // the transport layer successfully parsed.
+ SegmentsReceived tcpip.StatCounter
+
+ // SegmentsSent is the number of TCP segments sent.
+ SegmentsSent tcpip.StatCounter
+
+ // FailedConnectionAttempts is the number of times we saw Connect and
+ // Accept errors.
+ FailedConnectionAttempts tcpip.StatCounter
+
+ // ReceiveErrors collects segment receive errors within the
+ // transport layer.
+ ReceiveErrors ReceiveErrors
+
+ // ReadErrors collects segment read errors from an endpoint read call.
+ ReadErrors tcpip.ReadErrors
+
+ // SendErrors collects segment send errors within the transport layer.
+ SendErrors SendErrors
+
+ // WriteErrors collects segment write errors from an endpoint write call.
+ WriteErrors tcpip.WriteErrors
+}
+
+// IsEndpointStats is an empty method to implement the tcpip.EndpointStats
+// marker interface.
+func (*Stats) IsEndpointStats() {}
+
+// EndpointInfo holds useful information about a transport endpoint which
+// can be queried by monitoring tools.
+//
+// +stateify savable
+type EndpointInfo struct {
+ stack.TransportEndpointInfo
+
+ // HardError is meaningful only when state is stateError. It stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. HardError is protected by endpoint mu.
+ HardError *tcpip.Error `state:".(string)"`
+}
+
+// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
+// marker interface.
+func (*EndpointInfo) IsEndpointInfo() {}
+
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -178,6 +275,8 @@ type rcvBufAutoTuneParams struct {
//
// +stateify savable
type endpoint struct {
+ EndpointInfo
+
// workMu is used to arbitrate which goroutine may perform protocol
// work. Only the main protocol goroutine is expected to call Lock() on
// it, but other goroutines (e.g., send) may call TryLock() to eagerly
@@ -186,8 +285,7 @@ type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
- stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
+ stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
// lastError represents the last error that the endpoint reported;
@@ -218,7 +316,6 @@ type endpoint struct {
// The following fields are protected by the mutex.
mu sync.RWMutex `state:"nosave"`
- id stack.TransportEndpointID
state EndpointState `state:".(EndpointState)"`
@@ -226,6 +323,7 @@ type endpoint struct {
isRegistered bool
boundNICID tcpip.NICID `state:"manual"`
route stack.Route `state:"manual"`
+ ttl uint8
v6only bool
isConnectNotified bool
// TCP should never broadcast but Linux nevertheless supports enabling/
@@ -240,11 +338,6 @@ type endpoint struct {
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
- // hardError is meaningful only when state is stateError, it stores the
- // error to be returned when read/write syscalls are called and the
- // endpoint is in this state. hardError is protected by mu.
- hardError *tcpip.Error `state:".(string)"`
-
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
@@ -280,6 +373,9 @@ type endpoint struct {
// reusePort is set to true if SO_REUSEPORT is enabled.
reusePort bool
+ // bindToDevice is set to the NIC on which to bind or disabled if 0.
+ bindToDevice tcpip.NICID
+
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -393,13 +489,19 @@ type endpoint struct {
probe stack.TCPProbeFunc `state:"nosave"`
// The following are only used to assist the restore run to re-connect.
- bindAddress tcpip.Address
connectingAddress tcpip.Address
// amss is the advertised MSS to the peer by this endpoint.
amss uint16
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ sendTOS uint8
+
gso *stack.GSO
+
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats Stats `state:"nosave"`
}
// StopWork halts packet processing. Only to be used in tests.
@@ -427,10 +529,15 @@ type keepalive struct {
waker sleep.Waker `state:"nosave"`
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: stack,
- netProto: netProto,
+ stack: s,
+ EndpointInfo: EndpointInfo{
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: header.TCPProtocolNumber,
+ },
+ },
waiterQueue: waiterQueue,
state: StateInitial,
rcvBufSize: DefaultReceiveBufferSize,
@@ -446,26 +553,26 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
}
var ss SendBufferSizeOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
e.sndBufSize = ss.Default
}
var rs ReceiveBufferSizeOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
e.rcvBufSize = rs.Default
}
var cs tcpip.CongestionControlOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
e.cc = cs
}
var mrb tcpip.ModerateReceiveBufferOption
- if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
+ if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
e.rcvAutoParams.disabled = !bool(mrb)
}
- if p := stack.GetTCPProbe(); p != nil {
+ if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
@@ -564,11 +671,11 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -625,12 +732,12 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -731,11 +838,12 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
bufUsed := e.rcvBufUsed
if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
- he := e.hardError
+ he := e.HardError
e.mu.RUnlock()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
}
@@ -744,6 +852,9 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
e.mu.RUnlock()
+ if err == tcpip.ErrClosedForReceive {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ }
return v, tcpip.ControlMessages{}, err
}
@@ -787,7 +898,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
if !e.state.connected() {
switch e.state {
case StateError:
- return 0, e.hardError
+ return 0, e.HardError
default:
return 0, tcpip.ErrClosedForSend
}
@@ -806,7 +917,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// Linux completely ignores any address passed to sendto(2) for TCP sockets
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
@@ -818,50 +929,57 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
if err != nil {
e.sndBufMu.Unlock()
e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
return 0, nil, err
}
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
-
- // Nothing to do if the buffer is empty.
- if p.Size() == 0 {
- return 0, nil, nil
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
}
- // Copy in memory without holding sndBufMu so that worker goroutine can
- // make progress independent of this operation.
- v, perr := p.Get(avail)
- if perr != nil {
+ // Fetch data.
+ v, perr := p.Payload(avail)
+ if perr != nil || len(v) == 0 {
+ if opts.Atomic { // See above.
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ }
+ // Note that perr may be nil if len(v) == 0.
return 0, nil, perr
}
- e.mu.RLock()
- e.sndBufMu.Lock()
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a
- // write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- return 0, nil, err
- }
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
- // Discard any excess data copied in due to avail being reduced due to a
- // simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
}
// Add data to the send queue.
- l := len(v)
- s := newSegmentFromView(&e.route, e.id, v)
- e.sndBufUsed += l
- e.sndBufInQueue += seqnum.Size(l)
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
-
e.sndBufMu.Unlock()
// Release the endpoint lock to prevent deadlocks due to lock
// order inversion when acquiring workMu.
@@ -875,7 +993,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return int64(l), nil, nil
+
+ return int64(len(v)), nil, nil
}
// Peek reads data without consuming it from the endpoint.
@@ -889,8 +1008,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// but has some pending unread data.
if s := e.state; !s.connected() && s != StateClose {
if s == StateError {
- return 0, tcpip.ControlMessages{}, e.hardError
+ return 0, tcpip.ControlMessages{}, e.HardError
}
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
}
@@ -899,6 +1019,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.state.connected() {
+ e.stats.ReadErrors.ReadClosed.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
@@ -946,62 +1067,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
}
-// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.DelayOption:
- if v == 0 {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.delay, 1)
- }
- return nil
-
- case tcpip.CorkOption:
- if v == 0 {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- return nil
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.reuseAddr = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.QuickAckOption:
- if v == 0 {
- atomic.StoreUint32(&e.slowAck, 1)
- } else {
- atomic.StoreUint32(&e.slowAck, 0)
- }
- return nil
-
- case tcpip.MaxSegOption:
- userMSS := v
- if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
- }
- e.mu.Lock()
- e.userMSS = int(userMSS)
- e.mu.Unlock()
- e.notifyProtocolGoroutine(notifyMSSChanged)
- return nil
-
+// SetSockOptInt sets a socket option.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ switch opt {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1065,9 +1133,87 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.sndBufMu.Unlock()
return nil
+ default:
+ return nil
+ }
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
+ const inetECNMask = 3
+ switch v := opt.(type) {
+ case tcpip.DelayOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.delay, 1)
+ }
+ return nil
+
+ case tcpip.CorkOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ return nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.reuseAddr = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
+ case tcpip.QuickAckOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.slowAck, 1)
+ } else {
+ atomic.StoreUint32(&e.slowAck, 0)
+ }
+ return nil
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.mu.Lock()
+ e.userMSS = int(userMSS)
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrInvalidEndpointState
}
@@ -1082,6 +1228,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.v6only = v != 0
return nil
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+ return nil
+
case tcpip.KeepaliveEnabledOption:
e.keepalive.Lock()
e.keepalive.enabled = v != 0
@@ -1150,6 +1302,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// Linux returns ENOENT when an invalid congestion
// control algorithm is specified.
return tcpip.ErrNoSuchFile
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.mu.Unlock()
+ return nil
+
default:
return nil
}
@@ -1176,6 +1345,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
+ case tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ v := e.sndBufSize
+ e.sndBufMu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ v := e.rcvBufSize
+ e.rcvListMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -1198,18 +1379,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = header.TCPDefaultMSS
return nil
- case *tcpip.SendBufferSizeOption:
- e.sndBufMu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.sndBufMu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvListMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
- e.rcvListMu.Unlock()
- return nil
-
case *tcpip.DelayOption:
*o = 0
if v := atomic.LoadUint32(&e.delay); v != 0 {
@@ -1246,6 +1415,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = ""
+ return nil
+
case *tcpip.QuickAckOption:
*o = 1
if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1255,7 +1434,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrUnknownProtocolOption
}
@@ -1269,6 +1448,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.TTLOption:
+ e.mu.Lock()
+ *o = tcpip.TTLOption(e.ttl)
+ e.mu.Unlock()
+ return nil
+
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
e.mu.RLock()
@@ -1333,13 +1518,25 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case *tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ *o = tcpip.IPv4TOSOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
+ case *tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.netProto
+ netProto := e.NetProto
if header.IsV4MappedAddress(addr.Addr) {
// Fail if using a v4 mapped address on a v6only endpoint.
if e.v6only {
@@ -1355,7 +1552,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
// Fail if we're bound to an address length different from the one we're
// checking.
- if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
+ if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -1369,7 +1566,12 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- return e.connect(addr, true, true)
+ err := e.connect(addr, true, true)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
+ return err
}
// connect connects the endpoint to its peer. In the normal non-S/R case, the
@@ -1378,14 +1580,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// created (so no new handshaking is done); for stack-accepted connections not
// yet accepted by the app, they are restored without running the main goroutine
// here.
-func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) {
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- defer func() {
- if err != nil && !err.IgnoreStats() {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- }
- }()
connectingAddr := addr.Addr
@@ -1430,29 +1627,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
return tcpip.ErrAlreadyConnecting
case StateError:
- return e.hardError
+ return e.HardError
default:
return tcpip.ErrInvalidEndpointState
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
defer r.Release()
- origID := e.id
+ origID := e.ID
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- e.id.LocalAddress = r.LocalAddress
- e.id.RemoteAddress = r.RemoteAddress
- e.id.RemotePort = addr.Port
+ e.ID.LocalAddress = r.LocalAddress
+ e.ID.RemoteAddress = r.RemoteAddress
+ e.ID.RemotePort = addr.Port
- if e.id.LocalPort != 0 {
+ if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1461,20 +1658,35 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// one. Make sure that it isn't one that will result in the same
// address/port for both local and remote (otherwise this
// endpoint would be trying to connect to itself).
- sameAddr := e.id.LocalAddress == e.id.RemoteAddress
- if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- if sameAddr && p == e.id.RemotePort {
+ sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress
+
+ // Calculate a port offset based on the destination IP/port and
+ // src IP to ensure that for a given tuple (srcIP, destIP,
+ // destPort) the offset used as a starting point is the same to
+ // ensure that we can cycle through the port space effectively.
+ h := jenkins.Sum32(e.stack.PortSeed())
+ h.Write([]byte(e.ID.LocalAddress))
+ h.Write([]byte(e.ID.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
+ h.Write(portBuf)
+ portOffset := h.Sum32()
+
+ if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
+ if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ // reusePort is false below because connect cannot reuse a port even if
+ // reusePort was set.
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
return false, nil
}
- id := e.id
+ id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
- e.id = id
+ e.ID = id
return true, nil
case tcpip.ErrPortInUse:
return false, nil
@@ -1490,7 +1702,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -1509,7 +1721,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
e.segmentQueue.mu.Lock()
for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
- s.id = e.id
+ s.id = e.ID
s.route = r.Clone()
e.sndWaker.Assert()
}
@@ -1569,7 +1781,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Queue fin segment.
- s := newSegmentFromView(&e.route, e.id, nil)
+ s := newSegmentFromView(&e.route, e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
@@ -1597,14 +1809,18 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
-func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
+ err := e.listen(backlog)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
+ return err
+}
+
+func (e *endpoint) listen(backlog int) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- defer func() {
- if err != nil && !err.IgnoreStats() {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- }
- }()
// Allow the backlog to be adjusted if the endpoint is not shutting down.
// When the endpoint shuts down, it sets workerCleanup to true, and from
@@ -1630,11 +1846,12 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
// Endpoint must be bound before it can transition to listen mode.
if e.state != StateBound {
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
return tcpip.ErrInvalidEndpointState
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil {
return err
}
@@ -1698,7 +1915,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
return tcpip.ErrAlreadyBound
}
- e.bindAddress = addr.Addr
+ e.BindAddr = addr.Addr
netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
@@ -1715,26 +1932,26 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
e.isPortReserved = true
e.effectiveNetProtos = netProtos
- e.id.LocalPort = port
+ e.ID.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func() {
+ defer func(bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
- e.id.LocalPort = 0
- e.id.LocalAddress = ""
+ e.ID.LocalPort = 0
+ e.ID.LocalAddress = ""
e.boundNICID = 0
}
- }()
+ }(e.bindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
@@ -1745,7 +1962,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
e.boundNICID = nic
- e.id.LocalAddress = addr.Addr
+ e.ID.LocalAddress = addr.Addr
}
// Mark endpoint as bound.
@@ -1760,8 +1977,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
return tcpip.FullAddress{
- Addr: e.id.LocalAddress,
- Port: e.id.LocalPort,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
NIC: e.boundNICID,
}, nil
}
@@ -1776,8 +1993,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}
return tcpip.FullAddress{
- Addr: e.id.RemoteAddress,
- Port: e.id.RemotePort,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
NIC: e.boundNICID,
}, nil
}
@@ -1789,6 +2006,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment()
e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
s.decRef()
return
}
@@ -1796,11 +2014,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
if !s.csumValid {
e.stack.Stats().MalformedRcvdPackets.Increment()
e.stack.Stats().TCP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
s.decRef()
return
}
e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ e.stats.SegmentsReceived.Increment()
if (s.flags & header.TCPFlagRst) != 0 {
e.stack.Stats().TCP.ResetsReceived.Increment()
}
@@ -1811,6 +2031,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
} else {
// The queue is full, so we drop the segment.
e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.SegmentQueueDropped.Increment()
s.decRef()
}
}
@@ -1860,6 +2081,7 @@ func (e *endpoint) readyToRead(s *segment) {
// that a subsequent read of the segment will correctly trigger
// a non-zero notification.
if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 {
+ e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
e.zeroWindow = true
}
e.rcvList.PushBack(s)
@@ -2012,7 +2234,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
// Copy EndpointID.
e.mu.Lock()
- s.ID = stack.TCPEndpointID(e.id)
+ s.ID = stack.TCPEndpointID(e.ID)
e.mu.Unlock()
// Copy endpoint rcv state.
@@ -2119,7 +2341,7 @@ func (e *endpoint) initGSO() {
gso.Type = stack.GSOTCPv6
gso.L3HdrLen = header.IPv6MinimumSize
default:
- panic(fmt.Sprintf("Unknown netProto: %v", e.netProto))
+ panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto))
}
gso.NeedsCsum = true
gso.CsumOffset = header.TCPChecksumOffset
@@ -2135,6 +2357,20 @@ func (e *endpoint) State() uint32 {
return uint32(e.state)
}
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.EndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
+}
+
func mssForRoute(r *stack.Route) uint16 {
return uint16(r.MTU() - header.TCPMinimumSize)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 831389ec7..eae17237e 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -55,7 +55,7 @@ func (e *endpoint) beforeSave() {
case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
- panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
+ panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
}
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
e.mu.Unlock()
@@ -190,10 +190,10 @@ func (e *endpoint) Resume(s *stack.Stack) {
bind := func() {
e.state = StateInitial
- if len(e.bindAddress) == 0 {
- e.bindAddress = e.id.LocalAddress
+ if len(e.BindAddr) == 0 {
+ e.BindAddr = e.ID.LocalAddress
}
- if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
+ if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil {
panic("endpoint binding failed: " + err.String())
}
}
@@ -202,19 +202,19 @@ func (e *endpoint) Resume(s *stack.Stack) {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
bind()
if len(e.connectingAddress) == 0 {
- e.connectingAddress = e.id.RemoteAddress
+ e.connectingAddress = e.ID.RemoteAddress
// This endpoint is accepted by netstack but not yet by
// the app. If the endpoint is IPv6 but the remote
// address is IPv4, we need to connect as IPv6 so that
// dual-stack mode can be properly activated.
- if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
- e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
+ if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress
}
}
// Reset the scoreboard to reinitialize the sack information as
// we do not restore SACK information.
e.scoreboard.Reset()
- if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
connectedLoading.Done()
@@ -236,7 +236,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectedLoading.Wait()
listenLoading.Wait()
bind()
- if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
connectingLoading.Done()
@@ -288,21 +288,21 @@ func (e *endpoint) loadLastError(s string) {
}
// saveHardError is invoked by stateify.
-func (e *endpoint) saveHardError() string {
- if e.hardError == nil {
+func (e *EndpointInfo) saveHardError() string {
+ if e.HardError == nil {
return ""
}
- return e.hardError.String()
+ return e.HardError.String()
}
// loadHardError is invoked by stateify.
-func (e *endpoint) loadHardError(s string) {
+func (e *EndpointInfo) loadHardError(s string) {
if s == "" {
return
}
- e.hardError = loadError(s)
+ e.HardError = loadError(s)
}
var messageToError map[string]*tcpip.Error
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index ee04dcfcc..db40785d3 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -14,7 +14,7 @@
// Package tcp contains the implementation of the TCP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// activated on the stack by passing tcp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing tcp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -34,9 +34,6 @@ import (
)
const (
- // ProtocolName is the string representation of the tcp protocol name.
- ProtocolName = "tcp"
-
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
@@ -129,7 +126,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
@@ -156,7 +153,7 @@ func replyWithReset(s *segment) {
ack := s.sequenceNumber.Add(s.logicalLen())
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */)
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
// SetOption implements TransportProtocol.SetOption.
@@ -254,13 +251,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
- congestionControl: ccReno,
- availableCongestionControl: []string{ccReno, ccCubic},
- }
- })
+// NewProtocol returns a TCP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 1f9b1e0ef..8332a0179 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -417,6 +417,7 @@ func (s *sender) resendSegment() {
s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
s.sendSegment(seg)
s.ep.stack.Stats().TCP.FastRetransmit.Increment()
+ s.ep.stats.SendErrors.FastRetransmit.Increment()
// Run SetPipe() as per RFC 6675 section 5 Step 4.4
s.SetPipe()
@@ -435,6 +436,7 @@ func (s *sender) retransmitTimerExpired() bool {
}
s.ep.stack.Stats().TCP.Timeouts.Increment()
+ s.ep.stats.SendErrors.Timeouts.Increment()
// Give up if we've waited more than a minute since the last resend.
if s.rto >= 60*time.Second {
@@ -664,7 +666,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
segEnd = seg.sequenceNumber.Add(1)
// Transition to FIN-WAIT1 state since we're initiating an active close.
s.ep.mu.Lock()
- s.ep.state = StateFinWait1
+ switch s.ep.state {
+ case StateCloseWait:
+ // We've already received a FIN and are now sending our own. The
+ // sender is now awaiting a final ACK for this FIN.
+ s.ep.state = StateLastAck
+ default:
+ s.ep.state = StateFinWait1
+ }
s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
@@ -1181,6 +1190,7 @@ func (s *sender) handleRcvdSegment(seg *segment) {
func (s *sender) sendSegment(seg *segment) *tcpip.Error {
if !seg.xmitTime.IsZero() {
s.ep.stack.Stats().TCP.Retransmits.Increment()
+ s.ep.stats.SendErrors.Retransmits.Increment()
if s.sndCwnd < s.sndSsthresh {
s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 272bbcdbd..782d7b42c 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
enableCUBIC(t, c)
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -500,6 +500,14 @@ func TestRetransmit(t *testing.T) {
t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
}
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
+ t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
+ t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ }
+
if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 4e7f1a740..afea124ec 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -520,10 +520,18 @@ func TestSACKRecovery(t *testing.T) {
t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
}
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
+ t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
+ }
+
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
}
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
+ t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
+ }
+
c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
// Acknowledge all pending data to recover point.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index f79b8ec5f..6d022a266 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ActiveConnectionOpenings.Value() + 1
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -97,9 +97,12 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.FailedConnectionAttempts.Value()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want)
}
}
@@ -122,6 +125,9 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
+ t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want)
+ }
}
func TestTCPSegmentsSentIncrement(t *testing.T) {
@@ -131,11 +137,14 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
stats := c.Stack().Stats()
// SYN and ACK
want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
+ t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want)
+ }
}
func TestTCPResetsSentIncrement(t *testing.T) {
@@ -190,21 +199,122 @@ func TestTCPResetsSentIncrement(t *testing.T) {
}
}
+// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
+// a RST if an ACK is received on the listening socket for which there is no
+// active handshake in progress and we are not using SYN cookies.
+func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ c.GetPacket()
+
+ // Now resend the same ACK, this ACK should generate a RST as there
+ // should be no endpoint in SYN-RCVD state and we are not using
+ // syn-cookies yet. The reason we send the same ACK is we need a valid
+ // cookie(IRS) generated by the netstack without which the ACK will be
+ // rejected.
+ c.SendPacket(nil, ackHeaders)
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+}
+
func TestTCPResetsReceivedIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
stats := c.Stack().Stats()
want := stats.TCP.ResetsReceived.Value() + 1
- ackNum := seqnum.Value(789)
+ iss := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
- c.CreateConnected(ackNum, rcvWnd, nil)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
- SeqNum: c.IRS.Add(2),
- AckNum: ackNum.Add(2),
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
RcvWnd: rcvWnd,
Flags: header.TCPFlagRst,
})
@@ -214,18 +324,43 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
}
}
+func TestTCPResetsDoNotGenerateResets(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.ResetsReceived.Value() + 1
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ Flags: header.TCPFlagRst,
+ })
+
+ if got := stats.TCP.ResetsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ }
+ c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
+}
+
func TestActiveHandshake(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
}
func TestNonBlockingClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -241,7 +376,7 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -291,7 +426,7 @@ func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -339,11 +474,172 @@ func TestSimpleReceive(t *testing.T) {
)
}
+func TestTOSV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ c.EP = ep
+
+ const tos = 0xC0
+ if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
+ t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ }
+
+ var v tcpip.IPv4TOSOption
+ if err := c.EP.GetSockOpt(&v); err != nil {
+ t.Errorf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.IPv4TOSOption(tos); v != want {
+ t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ }
+
+ testV4Connect(t, c, checker.TOS(tos, 0))
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790), // Acknum is initial sequence number + 1
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ checker.TOS(tos, 0),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+}
+
+func TestTrafficClassV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ const tos = 0xC0
+ if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
+ t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err)
+ }
+
+ var v tcpip.IPv6TrafficClassOption
+ if err := c.EP.GetSockOpt(&v); err != nil {
+ t.Fatalf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.IPv6TrafficClassOption(tos); v != want {
+ t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ }
+
+ // Test the connection request.
+ testV6Connect(t, c, checker.TOS(tos, 0))
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b := c.GetV6Packet()
+ checker.IPv6(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ checker.TOS(tos, 0),
+ )
+
+ if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+}
+
+func TestConnectBindToDevice(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ device string
+ want tcp.EndpointState
+ }{
+ {"RightDevice", "nic1", tcp.StateEstablished},
+ {"WrongDevice", "nic2", tcp.StateSynSent},
+ {"AnyDevice", "", tcp.StateEstablished},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+ bindToDevice := tcpip.BindToDeviceOption(test.device)
+ c.EP.SetSockOpt(bindToDevice)
+ // Start connection attempt.
+ waitEntry, _ := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ c.GetPacket()
+ if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ })
+ }
+}
+
func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -431,8 +727,7 @@ func TestOutOfOrderFlood(t *testing.T) {
defer c.Cleanup()
// Create a new connection with initial window size of 10.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -505,7 +800,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -574,7 +869,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -659,7 +954,7 @@ func TestShutdownRead(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -672,14 +967,17 @@ func TestShutdownRead(t *testing.T) {
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
}
+ var want uint64 = 1
+ if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
+ t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want)
+ }
}
func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -729,6 +1027,11 @@ func TestFullWindowReceive(t *testing.T) {
t.Fatalf("got data = %v, want = %v", v, data)
}
+ var want uint64 = 1
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
+ t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want)
+ }
+
// Check that we get an ACK for the newly non-zero window.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
@@ -746,11 +1049,9 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.Cleanup()
// Start off with a window size of 10, then shrink it to 5.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
- opt = 5
- if err := c.EP.SetSockOpt(opt); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
t.Fatalf("SetSockOpt failed: %v", err)
}
@@ -850,7 +1151,7 @@ func TestSimpleSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -891,7 +1192,7 @@ func TestZeroWindowSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 0, nil)
+ c.CreateConnected(789, 0, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -949,8 +1250,7 @@ func TestScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -984,8 +1284,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 65535*3)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1025,7 +1324,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1098,7 +1397,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1167,8 +1466,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// Set the window size such that a window scale of 4 will be used.
const wnd = 65535 * 10
const ws = uint32(4)
- opt := tcpip.ReceiveBufferSizeOption(wnd)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -1273,7 +1571,7 @@ func TestSegmentMerging(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Prevent the endpoint from processing packets.
test.stop(c.EP)
@@ -1323,7 +1621,7 @@ func TestDelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1371,7 +1669,7 @@ func TestUndelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1453,7 +1751,7 @@ func TestMSSNotDelayed(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -1569,16 +1867,44 @@ func TestSendGreaterThanMTU(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
testBrokenUpWrite(t, c, maxPayload)
}
+func TestSetTTL(t *testing.T) {
+ for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
+ t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
+ c := context.New(t, 65535)
+ defer c.Cleanup()
+
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
+ t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+
+ checker.IPv4(t, b, checker.TTL(wantTTL))
+ })
+ }
+}
+
func TestActiveSendMSSLessThanMTU(t *testing.T) {
const maxPayload = 100
c := context.New(t, 65535)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
testBrokenUpWrite(t, c, maxPayload)
@@ -1601,7 +1927,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1745,7 +2071,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 2
- if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1847,7 +2173,7 @@ func TestReceiveOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -1878,13 +2204,20 @@ loop:
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
}
}
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ }
+ if tcp.EndpointState(c.EP.State()) != tcp.StateError {
+ t.Fatalf("got EP state is not StateError")
+ }
}
func TestSendOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -1909,7 +2242,7 @@ func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -1952,7 +2285,7 @@ func TestFinRetransmit(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -2006,7 +2339,7 @@ func TestFinWithNoPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
@@ -2077,7 +2410,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
@@ -2165,7 +2498,7 @@ func TestFinWithPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
@@ -2251,7 +2584,7 @@ func TestFinWithPartialAck(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
@@ -2383,7 +2716,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
defer c.Cleanup()
maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(789, 0, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
})
@@ -2433,7 +2766,7 @@ func TestScaledSendWindow(t *testing.T) {
func TestReceivedValidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ValidSegmentsReceived.Value() + 1
@@ -2449,12 +2782,23 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
+ t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want)
+ }
+ // Ensure there were no errors during handshake. If these stats have
+ // incremented, then the connection should not have been established.
+ if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
+ t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
+ t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0)
+ }
}
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.InvalidSegmentsReceived.Value() + 1
vv := c.BuildSegment(nil, &context.Headers{
@@ -2473,12 +2817,15 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ }
}
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ChecksumErrors.Value() + 1
vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
@@ -2499,6 +2846,9 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
if got := stats.TCP.ChecksumErrors.Value(); got != want {
t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want)
+ }
}
func TestReceivedSegmentQueuing(t *testing.T) {
@@ -2509,7 +2859,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send 200 segments.
data := []byte{1, 2, 3}
@@ -2555,7 +2905,7 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -2730,8 +3080,8 @@ func TestReusePort(t *testing.T) {
func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2743,8 +3093,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2754,7 +3104,10 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
}
func TestDefaultBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2800,7 +3153,10 @@ func TestDefaultBufferSizes(t *testing.T) {
}
func TestMinMaxBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2819,37 +3175,96 @@ func TestMinMaxBufferSizes(t *testing.T) {
}
// Set values below the min.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
}
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ defer ep.Close()
+
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
+
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
+ }
+
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
+ }
+
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
+ }
+}
+
func makeStack() (*stack.Stack, *tcpip.Error) {
- s := stack.New([]string{
- ipv4.ProtocolName,
- ipv6.ProtocolName,
- }, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ ipv6.NewProtocol(),
+ },
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
id := loopback.New()
if testing.Verbose() {
@@ -3105,7 +3520,7 @@ func TestPathMTUDiscovery(t *testing.T) {
// Create new connection with MSS of 1460.
const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -3182,7 +3597,7 @@ func TestTCPEndpointProbe(t *testing.T) {
invoked <- struct{}{}
})
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -3356,7 +3771,7 @@ func TestKeepalive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
@@ -3459,8 +3874,8 @@ func TestKeepalive(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
}
}
@@ -3886,6 +4301,9 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want)
}
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
+ t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want)
+ }
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -3924,6 +4342,14 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ // Expect InvalidEndpointState errors on a read at this point.
+ if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState)
+ }
+ if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 {
+ t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1)
+ }
+
if err := ep.Listen(10); err != nil {
t.Fatalf("Listen failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 272481aa0..ef823e4ae 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -137,7 +137,10 @@ type Context struct {
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Allow minimum send/receive buffer sizes to be 1 during tests.
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
@@ -150,11 +153,19 @@ func New(t *testing.T, mtu uint32) *Context {
// Some of the congestion control tests send up to 640 packets, we so
// set the channel size to 1000.
- id, linkEP := channel.New(1000, mtu, "")
+ ep := channel.New(1000, mtu, "")
+ wep := stack.LinkEndpoint(ep)
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+ if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
if testing.Verbose() {
- id = sniffer.New(id)
+ wep2 = sniffer.New(channel.New(1000, mtu, ""))
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -180,7 +191,7 @@ func New(t *testing.T, mtu uint32) *Context {
return &Context{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
}
}
@@ -267,7 +278,7 @@ func (c *Context) GetPacketNonBlocking() []byte {
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
// Allocate a buffer data and headers.
- buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2))
+ buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
if len(buf) > maxTotalSize {
buf = buf[:maxTotalSize]
}
@@ -286,9 +297,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
icmp.SetType(typ)
icmp.SetCode(code)
-
- copy(icmp[header.ICMPv4PayloadOffset:], p1)
- copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2)
+ const icmpv4VariableHeaderOffset = 4
+ copy(icmp[icmpv4VariableHeaderOffset:], p1)
+ copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
@@ -511,7 +522,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
}
// CreateConnected creates a connected TCP endpoint.
-func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) {
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
@@ -584,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.Port = tcpHdr.SourcePort()
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+// Create creates a TCP endpoint.
+func (c *Context) Create(epRcvBuf int) {
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -597,11 +604,20 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ if epRcvBuf != -1 {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+ c.Create(epRcvBuf)
c.Connect(iss, rcvWnd, options)
}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 4bec48c0f..43fcc27f0 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index ac2666f69..c9460aa0d 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -1,7 +1,8 @@
-package(licenses = ["notice"])
-
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
go_template_instance(
name = "udp_packet_list",
@@ -50,6 +51,7 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index ac5905772..6e87245b7 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,7 +15,6 @@
package udp
import (
- "math"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -37,15 +36,35 @@ type udpPacket struct {
views [8]buffer.View `state:"nosave"`
}
-type endpointState int
+// EndpointState represents the state of a UDP endpoint.
+type EndpointState uint32
+// Endpoint states. Note that are represented in a netstack-specific manner and
+// may not be meaningful externally. Specifically, they need to be translated to
+// Linux's representation for these states if presented to userspace.
const (
- stateInitial endpointState = iota
- stateBound
- stateConnected
- stateClosed
+ StateInitial EndpointState = iota
+ StateBound
+ StateConnected
+ StateClosed
)
+// String implements fmt.Stringer.String.
+func (s EndpointState) String() string {
+ switch s {
+ case StateInitial:
+ return "INITIAL"
+ case StateBound:
+ return "BOUND"
+ case StateConnected:
+ return "CONNECTING"
+ case StateClosed:
+ return "CLOSED"
+ default:
+ return "UNKNOWN"
+ }
+}
+
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -55,10 +74,11 @@ const (
//
// +stateify savable
type endpoint struct {
+ stack.TransportEndpointInfo
+
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
// The following fields are used to manage the receive queue, and are
@@ -73,20 +93,23 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
- id stack.TransportEndpointID
- state endpointState
- bindNICID tcpip.NICID
- regNICID tcpip.NICID
+ state EndpointState
route stack.Route `state:"manual"`
dstPort uint16
v6only bool
+ ttl uint8
multicastTTL uint8
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
multicastLoop bool
reusePort bool
+ bindToDevice tcpip.NICID
broadcast bool
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ sendTOS uint8
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -101,6 +124,9 @@ type endpoint struct {
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
+
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats tcpip.TransportEndpointStats `state:"nosave"`
}
// +stateify savable
@@ -109,10 +135,13 @@ type multicastMembership struct {
multicastAddr tcpip.Address
}
-func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
return &endpoint{
- stack: stack,
- netProto: netProto,
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: header.TCPProtocolNumber,
+ },
waiterQueue: waiterQueue,
// RFC 1075 section 5.4 recommends a TTL of 1 for membership
// requests.
@@ -130,6 +159,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
+ state: StateInitial,
}
}
@@ -140,13 +170,13 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
- case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ case StateBound, StateConnected:
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
}
for _, mem := range e.multicastMemberships {
- e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
}
e.multicastMemberships = nil
@@ -163,7 +193,7 @@ func (e *endpoint) Close() {
e.route.Release()
// Update the state.
- e.state = stateClosed
+ e.state = StateClosed
e.mu.Unlock()
@@ -186,6 +216,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
if e.rcvList.Empty() {
err := tcpip.ErrWouldBlock
if e.rcvClosed {
+ e.stats.ReadErrors.ReadClosed.Increment()
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
@@ -211,11 +242,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
switch e.state {
- case stateInitial:
- case stateConnected:
+ case StateInitial:
+ case StateConnected:
return false, nil
- case stateBound:
+ case StateBound:
if to == nil {
return false, tcpip.ErrDestinationRequired
}
@@ -232,7 +263,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return true, nil
}
@@ -248,7 +279,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
- localAddr := e.id.LocalAddress
+ localAddr := e.ID.LocalAddress
if isBroadcastOrMulticast(localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
localAddr = ""
@@ -273,17 +304,35 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ n, ch, err := e.write(p, opts)
+ switch err {
+ case nil:
+ e.stats.PacketsSent.Increment()
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ e.stats.WriteErrors.InvalidArgs.Increment()
+ case tcpip.ErrClosedForSend:
+ e.stats.WriteErrors.WriteClosed.Increment()
+ case tcpip.ErrInvalidEndpointState:
+ e.stats.WriteErrors.InvalidEndpointState.Increment()
+ case tcpip.ErrNoLinkAddress:
+ e.stats.SendErrors.NoLinkAddr.Increment()
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ // Errors indicating any problem with IP routing of the packet.
+ e.stats.SendErrors.NoRoute.Increment()
+ default:
+ // For all other errors when writing to the network layer.
+ e.stats.SendErrors.SendToNetworkFailed.Increment()
+ }
+ return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
}
- if p.Size() > math.MaxUint16 {
- // Payload can't possibly fit in a packet.
- return 0, nil, tcpip.ErrMessageTooLong
- }
-
to := opts.To
e.mu.RLock()
@@ -322,7 +371,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
- if e.state != stateConnected {
+ if e.state != StateConnected {
return 0, nil, tcpip.ErrInvalidEndpointState
}
}
@@ -330,12 +379,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicid := to.NIC
- if e.bindNICID != 0 {
- if nicid != 0 && nicid != e.bindNICID {
+ if e.BindNICID != 0 {
+ if nicid != 0 && nicid != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.bindNICID
+ nicid = e.BindNICID
}
if to.Addr == header.IPv4Broadcast && !e.broadcast {
@@ -366,17 +415,25 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
+ if len(v) > header.UDPMaximumPacketSize {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
+
+ ttl := e.ttl
+ useDefaultTTL := ttl == 0
- ttl := route.DefaultTTL()
if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
ttl = e.multicastTTL
+ // Multicast allows a 0 TTL.
+ useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -387,12 +444,17 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOpt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrInvalidEndpointState
}
@@ -400,12 +462,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
e.v6only = v != 0
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
e.multicastTTL = uint8(v)
@@ -440,7 +507,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
- if e.bindNICID != 0 && e.bindNICID != nic {
+ if e.BindNICID != 0 && e.BindNICID != nic {
return tcpip.ErrInvalidEndpointState
}
@@ -467,7 +534,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
} else {
- nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
}
if nicID == 0 {
return tcpip.ErrUnknownDevice
@@ -484,7 +551,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
- if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
return err
}
@@ -505,7 +572,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
} else {
- nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
}
if nicID == 0 {
return tcpip.ErrUnknownDevice
@@ -527,7 +594,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrBadLocalAddress
}
- if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
return err
}
@@ -544,12 +611,39 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.reusePort = v != 0
e.mu.Unlock()
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
e.mu.Unlock()
return nil
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+ return nil
}
return nil
}
@@ -566,7 +660,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
}
+
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -576,21 +683,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
- if e.netProto != header.IPv6ProtocolNumber {
+ if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrUnknownProtocolOption
}
@@ -604,6 +699,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.TTLOption:
+ e.mu.Lock()
+ *o = tcpip.TTLOption(e.ttl)
+ e.mu.Unlock()
+ return nil
+
case *tcpip.MulticastTTLOption:
e.mu.Lock()
*o = tcpip.MulticastTTLOption(e.multicastTTL)
@@ -638,6 +739,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = tcpip.BindToDeviceOption("")
+ return nil
+
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -653,6 +764,18 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ *o = tcpip.IPv4TOSOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
+ case *tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -660,7 +783,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -683,14 +806,21 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
udp.SetChecksum(^udp.CalculateChecksum(xsum))
}
+ if useDefaultTTL {
+ ttl = r.DefaultTTL()
+ }
+ if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ r.Stats().UDP.PacketSendErrors.Increment()
+ return err
+ }
+
// Track count of packets sent.
r.Stats().UDP.PacketsSent.Increment()
-
- return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl)
+ return nil
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.netProto
+ netProto := e.NetProto
if len(addr.Addr) == 0 {
return netProto, nil
}
@@ -707,14 +837,14 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
}
// Fail if we are bound to an IPv6 address.
- if !allowMismatch && len(e.id.LocalAddress) == 16 {
+ if !allowMismatch && len(e.ID.LocalAddress) == 16 {
return 0, tcpip.ErrNetworkUnreachable
}
}
// Fail if we're bound to an address length different from the one we're
// checking.
- if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) {
+ if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -726,28 +856,32 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != stateConnected {
+ if e.state != StateConnected {
return nil
}
id := stack.TransportEndpointID{}
// Exclude ephemerally bound endpoints.
- if e.bindNICID != 0 || e.id.LocalAddress == "" {
+ if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
id = stack.TransportEndpointID{
- LocalPort: e.id.LocalPort,
- LocalAddress: e.id.LocalAddress,
+ LocalPort: e.ID.LocalPort,
+ LocalAddress: e.ID.LocalAddress,
}
- id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
return err
}
- e.state = stateBound
+ e.state = StateBound
} else {
- e.state = stateInitial
+ if e.ID.LocalPort != 0 {
+ // Release the ephemeral port.
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ }
+ e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
- e.id = id
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.ID = id
e.route.Release()
e.route = stack.Route{}
e.dstPort = 0
@@ -772,18 +906,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicid := addr.NIC
var localPort uint16
switch e.state {
- case stateInitial:
- case stateBound, stateConnected:
- localPort = e.id.LocalPort
- if e.bindNICID == 0 {
+ case StateInitial:
+ case StateBound, StateConnected:
+ localPort = e.ID.LocalPort
+ if e.BindNICID == 0 {
break
}
- if nicid != 0 && nicid != e.bindNICID {
+ if nicid != 0 && nicid != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.bindNICID
+ nicid = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
@@ -795,13 +929,13 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
defer r.Release()
id := stack.TransportEndpointID{
- LocalAddress: e.id.LocalAddress,
+ LocalAddress: e.ID.LocalAddress,
LocalPort: localPort,
RemotePort: addr.Port,
RemoteAddress: r.RemoteAddress,
}
- if e.state == stateInitial {
+ if e.state == StateInitial {
id.LocalAddress = r.LocalAddress
}
@@ -822,17 +956,17 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Remove the old registration.
- if e.id.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ if e.ID.LocalPort != 0 {
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
}
- e.id = id
+ e.ID = id
e.route = r.Clone()
e.dstPort = addr.Port
- e.regNICID = nicid
+ e.RegisterNICID = nicid
e.effectiveNetProtos = netProtos
- e.state = stateConnected
+ e.state = StateConnected
e.rcvMu.Lock()
e.rcvReady = true
@@ -854,7 +988,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
- if e.state != stateBound && e.state != stateConnected {
+ if e.state != StateBound && e.state != StateConnected {
return tcpip.ErrNotConnected
}
@@ -885,17 +1019,17 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
- if e.id.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
+ if e.ID.LocalPort == 0 {
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
if err != nil {
return id, err
}
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
}
return id, err
}
@@ -903,7 +1037,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -941,12 +1075,12 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return err
}
- e.id = id
- e.regNICID = nicid
+ e.ID = id
+ e.RegisterNICID = nicid
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
- e.state = stateBound
+ e.state = StateBound
e.rcvMu.Lock()
e.rcvReady = true
@@ -967,7 +1101,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// Save the effective NICID generated by bindLocked.
- e.bindNICID = e.regNICID
+ e.BindNICID = e.RegisterNICID
return nil
}
@@ -978,9 +1112,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
return tcpip.FullAddress{
- NIC: e.regNICID,
- Addr: e.id.LocalAddress,
- Port: e.id.LocalPort,
+ NIC: e.RegisterNICID,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
}, nil
}
@@ -989,14 +1123,14 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
+ if e.state != StateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
return tcpip.FullAddress{
- NIC: e.regNICID,
- Addr: e.id.RemoteAddress,
- Port: e.id.RemotePort,
+ NIC: e.RegisterNICID,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
}, nil
}
@@ -1026,6 +1160,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
if int(hdr.Length()) > vv.Size() {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
@@ -1033,11 +1168,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
+ e.stats.PacketsReceived.Increment()
// Drop the packet if our buffer is currently full.
- if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ if !e.rcvReady || e.rcvClosed {
+ e.rcvMu.Unlock()
e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
+ e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
return
}
@@ -1069,10 +1213,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
-// State implements socket.Socket.State.
+// State implements tcpip.Endpoint.State.
func (e *endpoint) State() uint32 {
- // TODO(b/112063468): Translate internal state to values returned by Linux.
- return 0
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return uint32(e.state)
+}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+ e.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := e.TransportEndpointInfo
+ e.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+ return &e.stats
}
func isBroadcastOrMulticast(a tcpip.Address) bool {
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 5cbb56120..b227e353b 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -72,12 +72,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
for _, m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
panic(err)
}
}
- if e.state != stateBound && e.state != stateConnected {
+ if e.state != StateBound && e.state != StateConnected {
return
}
@@ -92,14 +92,14 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
var err *tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
+ if e.state == StateConnected {
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop)
if err != nil {
panic(err)
}
- } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound
+ } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound
// A local unicast address is specified, verify that it's valid.
- if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
+ if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
@@ -107,9 +107,9 @@ func (e *endpoint) Resume(s *stack.Stack) {
// Our saved state had a port, but we don't actually have a
// reservation. We need to remove the port from our state, but still
// pass it to the reservation machinery.
- id := e.id
- e.id.LocalPort = 0
- e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ id := e.ID
+ e.ID.LocalPort = 0
+ e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a874fc9fd..d399ec722 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -74,17 +74,17 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
- ep.id = r.id
+ ep.ID = r.id
ep.route = r.route.Clone()
ep.dstPort = r.id.RemotePort
- ep.regNICID = r.route.NICID()
+ ep.RegisterNICID = r.route.NICID()
- ep.state = stateConnected
+ ep.state = StateConnected
ep.rcvMu.Lock()
ep.rcvReady = true
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index f76e7fbe1..de026880f 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -14,7 +14,7 @@
// Package udp contains the implementation of the UDP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// activated on the stack by passing udp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing udp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -30,9 +30,6 @@ import (
)
const (
- // ProtocolName is the string representation of the udp protocol name.
- ProtocolName = "udp"
-
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
)
@@ -69,7 +66,106 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ // Get the header then trim it from the view.
+ hdr := header.UDP(vv.First())
+ if int(hdr.Length()) > vv.Size() {
+ // Malformed packet.
+ r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ return true
+ }
+ // TODO(b/129426613): only send an ICMP message if UDP checksum is valid.
+
+ // Only send ICMP error if the address is not a multicast/broadcast
+ // v4/v6 address or the source is not the unspecified address.
+ //
+ // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
+ return true
+ }
+
+ // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
+ // Unreachable messages with code:
+ //
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported; or
+ //
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ switch len(id.LocalAddress) {
+ case header.IPv4AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
+ return true
+ }
+ // As per RFC 1812 Section 4.3.2.3
+ //
+ // ICMP datagram SHOULD contain as much of the original
+ // datagram as possible without the length of the ICMP
+ // datagram exceeding 576 bytes
+ //
+ // NOTE: The above RFC referenced is different from the original
+ // recommendation in RFC 1122 where it mentioned that at least 8
+ // bytes of the payload must be included. Today linux and other
+ // systems implement the] RFC1812 definition and not the original
+ // RFC 1122 requirement.
+ mtu := int(r.MTU())
+ if mtu > header.IPv4MinimumProcessableDatagramSize {
+ mtu = header.IPv4MinimumProcessableDatagramSize
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ available := int(mtu) - headerLen
+ payloadLen := len(netHeader) + vv.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+
+ payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
+ payload.Append(vv)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4DstUnreachable)
+ pkt.SetCode(header.ICMPv4PortUnreachable)
+ pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
+ r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
+
+ case header.IPv6AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
+ return true
+ }
+
+ // As per RFC 4443 section 2.4
+ //
+ // (c) Every ICMPv6 error message (type < 128) MUST include
+ // as much of the IPv6 offending (invoking) packet (the
+ // packet that caused the error) as possible without making
+ // the error message packet exceed the minimum IPv6 MTU
+ // [IPv6].
+ mtu := int(r.MTU())
+ if mtu > header.IPv6MinimumMTU {
+ mtu = header.IPv6MinimumMTU
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
+ available := int(mtu) - headerLen
+ payloadLen := len(netHeader) + vv.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+ payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
+ payload.Append(vv)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize))
+ pkt.SetType(header.ICMPv6DstUnreachable)
+ pkt.SetCode(header.ICMPv6PortUnreachable)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
+ r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
+ }
return true
}
@@ -83,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{}
- })
+// NewProtocol returns a UDP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 9da6edce2..b724d788c 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -17,7 +17,6 @@ package udp_test
import (
"bytes"
"fmt"
- "math"
"math/rand"
"testing"
"time"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -274,13 +274,17 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ ep := channel.New(256, mtu, "")
+ wep := stack.LinkEndpoint(ep)
- id, linkEP := channel.New(256, mtu, "")
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -306,7 +310,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
@@ -380,15 +384,17 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h)
+ c.injectV4Packet(payload, &h, true /* valid */)
} else {
- c.injectV6Packet(payload, &h)
+ c.injectV6Packet(payload, &h, true /* valid */)
}
}
// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
+// values, and injects it into the link endpoint. valid indicates if the
+// caller intends to inject a packet with a valid or an invalid UDP header.
+// We can invalidate the header by corrupting the UDP payload length.
+func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -405,10 +411,16 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
+ l := uint16(header.UDPMinimumSize + len(payload))
+ if !valid {
+ // Change the UDP payload length to corrupt the header
+ // as requested by the caller.
+ l++
+ }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
+ Length: l,
})
// Calculate the UDP pseudo-header checksum.
@@ -422,9 +434,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
-// injectV6Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
+// injectV4Packet creates a V4 test packet with the given payload and header
+// values, and injects it into the link endpoint. valid indicates if the
+// caller intends to inject a packet with a valid or an invalid UDP header.
+// We can invalidate the header by corrupting the UDP payload length.
+func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -461,101 +475,78 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
}
func newPayload() []byte {
- b := make([]byte, 30+rand.Intn(100))
+ return newMinPayload(30)
+}
+
+func newMinPayload(minSize int) []byte {
+ b := make([]byte, minSize+rand.Intn(100))
for i := range b {
b[i] = byte(rand.Intn(256))
}
return b
}
-func TestBindPortReuse(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- c.createEndpoint(ipv6.ProtocolNumber)
-
- var eps [5]tcpip.Endpoint
- reusePortOpt := tcpip.ReusePortOption(1)
-
- pollChannel := make(chan tcpip.Endpoint)
- for i := 0; i < len(eps); i++ {
- // Try to receive the data.
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
-
- var err *tcpip.Error
- eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- go func(ep tcpip.Endpoint) {
- for range ch {
- pollChannel <- ep
- }
- }(eps[i])
-
- defer eps[i].Close()
- if err := eps[i].SetSockOpt(reusePortOpt); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
}
+ defer ep.Close()
- npackets := 100000
- nports := 10000
- ports := make(map[uint16]tcpip.Endpoint)
- stats := make(map[tcpip.Endpoint]int)
- for i := 0; i < npackets; i++ {
- // Send a packet.
- port := uint16(i % nports)
- payload := newPayload()
- c.injectV6Packet(payload, &header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- })
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
- var addr tcpip.FullAddress
- ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
- stats[ep]++
- if i < nports {
- ports[uint16(i)] = ep
- } else {
- // Check that all packets from one client are handled
- // by the same socket.
- if ports[port] != ep {
- t.Fatalf("Port mismatch")
- }
- }
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
}
- if len(stats) != len(eps) {
- t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps))
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
}
- // Check that a packet distribution is fair between sockets.
- for _, c := range stats {
- n := float64(npackets) / float64(len(eps))
- // The deviation is less than 10%.
- if math.Abs(float64(c)-n) > n/10 {
- t.Fatal(c, n)
- }
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
}
}
-// testRead sends a packet of the given test flow into the stack by injecting it
-// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness.
-func testRead(c *testContext, flow testFlow) {
+// testReadInternal sends a packet of the given test flow into the stack by
+// injecting it into the link endpoint. It then attempts to read it from the
+// UDP endpoint and depending on if this was expected to succeed verifies its
+// correctness.
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
c.t.Helper()
payload := newPayload()
@@ -566,6 +557,9 @@ func testRead(c *testContext, flow testFlow) {
c.wq.EventRegister(&we, waiter.EventIn)
defer c.wq.EventUnregister(&we)
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+
var addr tcpip.FullAddress
v, _, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
@@ -573,25 +567,55 @@ func testRead(c *testContext, flow testFlow) {
select {
case <-ch:
v, _, err = c.ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for data")
+ case <-time.After(300 * time.Millisecond):
+ if packetShouldBeDropped {
+ return // expected to time out
+ }
+ c.t.Fatal("timed out waiting for data")
}
}
+ if expectReadError && err != nil {
+ c.checkEndpointReadStats(1, epstats, err)
+ return
+ }
+
+ if err != nil {
+ c.t.Fatal("Read failed:", err)
+ }
+
+ if packetShouldBeDropped {
+ c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
+ }
+
// Check the peer address.
h := flow.header4Tuple(incoming)
if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr)
+ c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr)
}
// Check the payload.
if !bytes.Equal(payload, v) {
- c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
+ c.t.Fatalf("bad payload: got %x, want %x", v, payload)
}
+ c.checkEndpointReadStats(1, epstats, err)
+}
+
+// testRead sends a packet of the given test flow into the stack by injecting it
+// into the link endpoint. It then reads it from the UDP endpoint and verifies
+// its correctness.
+func testRead(c *testContext, flow testFlow) {
+ c.t.Helper()
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
+}
+
+// testFailingRead sends a packet of the given test flow into the stack by
+// injecting it into the link endpoint. It then tries to read it from the UDP
+// endpoint and expects this to fail.
+func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
+ c.t.Helper()
+ testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
}
func TestBindEphemeralPort(t *testing.T) {
@@ -763,13 +787,17 @@ func TestReadOnBoundToMulticast(t *testing.T) {
c.t.Fatal("SetSockOpt failed:", err)
}
+ // Check that we receive multicast packets but not unicast or broadcast
+ // ones.
testRead(c, flow)
+ testFailingRead(c, broadcast, false /* expectReadError */)
+ testFailingRead(c, unicastV4, false /* expectReadError */)
})
}
}
// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
-// address and receive broadcast data on it.
+// address and can receive only broadcast data.
func TestV4ReadOnBoundToBroadcast(t *testing.T) {
for _, flow := range []testFlow{broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -784,8 +812,31 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
c.t.Fatalf("Bind failed: %s", err)
}
- // Test acceptance.
+ // Check that we receive broadcast packets but not unicast ones.
testRead(c, flow)
+ testFailingRead(c, unicastV4, false /* expectReadError */)
+ })
+ }
+}
+
+// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
+// and receive broadcast and unicast data.
+func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
+ for _, flow := range []testFlow{broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s (", err)
+ }
+
+ // Check that we receive both broadcast and unicast packets.
+ testRead(c, flow)
+ testRead(c, unicastV4)
})
}
}
@@ -794,7 +845,8 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
// and verifies it fails with the provided error code.
func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
c.t.Helper()
-
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
h := flow.header4Tuple(outgoing)
writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
@@ -802,6 +854,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
_, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
})
+ c.checkEndpointWriteStats(1, epstats, gotErr)
if gotErr != wantErr {
c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
}
@@ -827,6 +880,8 @@ func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...chec
func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
writeOpts := tcpip.WriteOptions{}
if setDest {
@@ -844,7 +899,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
-
+ c.checkEndpointWriteStats(1, epstats, err)
// Received the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
var udp header.UDP
@@ -913,6 +968,10 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Write to V4 mapped address.
testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
+ const want = 1
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
+ c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
+ }
}
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
@@ -1175,6 +1234,109 @@ func TestTTL(t *testing.T) {
}
}
+func TestSetTTL(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
+ t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ var p stack.NetworkProtocol
+ if flow.isV4() {
+ p = ipv4.NewProtocol()
+ } else {
+ p = ipv6.NewProtocol()
+ }
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ep.Close()
+
+ testWrite(c, flow, checker.TTL(wantTTL))
+ })
+ }
+ })
+ }
+}
+
+func TestTOSV4(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ const tos = 0xC0
+ var v tcpip.IPv4TOSOption
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+ // Test for expected default value.
+ if v != 0 {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ }
+
+ if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
+ c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ }
+
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.IPv4TOSOption(tos); v != want {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ }
+
+ testWrite(c, flow, checker.TOS(tos, 0))
+ })
+ }
+}
+
+func TestTOSV6(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ const tos = 0xC0
+ var v tcpip.IPv6TrafficClassOption
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+ // Test for expected default value.
+ if v != 0 {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ }
+
+ if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
+ c.t.Errorf("SetSockOpt failed: %s", err)
+ }
+
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.IPv6TrafficClassOption(tos); v != want {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ }
+
+ testWrite(c, flow, checker.TOS(tos, 0))
+ })
+ }
+}
+
func TestMulticastInterfaceOption(t *testing.T) {
for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1238,3 +1400,267 @@ func TestMulticastInterfaceOption(t *testing.T) {
})
}
}
+
+// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
+// Unreachable message when a udp datagram is received on ports for which there
+// is no bound udp socket.
+func TestV4UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload if true, will result in a payload large enough
+ // so that the final generated IPv4 packet is larger than
+ // header.IPv4MinimumProcessableDatagramSize.
+ largePayload bool
+ }{
+ {unicastV4, true, false},
+ {unicastV4, true, true},
+ {multicastV4, false, false},
+ {multicastV4, false, true},
+ {broadcast, false, false},
+ {broadcast, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(576)
+ }
+ c.injectPacket(tc.flow, payload)
+ if !tc.icmpRequired {
+ select {
+ case p := <-c.linkEP.C:
+ t.Fatalf("unexpected packet received: %+v", p)
+ case <-time.After(1 * time.Second):
+ return
+ }
+ }
+
+ select {
+ case p := <-c.linkEP.C:
+ var pkt []byte
+ pkt = append(pkt, p.Header...)
+ pkt = append(pkt, p.Payload...)
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
+
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("packet wasn't written out")
+ }
+ })
+ }
+}
+
+// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
+// Unreachable message when a udp datagram is received on ports for which there
+// is no bound udp socket.
+func TestV6UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload if true will result in a payload large enough to
+ // create an IPv6 packet > header.IPv6MinimumMTU bytes.
+ largePayload bool
+ }{
+ {unicastV6, true, false},
+ {unicastV6, true, true},
+ {multicastV6, false, false},
+ {multicastV6, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(1280)
+ }
+ c.injectPacket(tc.flow, payload)
+ if !tc.icmpRequired {
+ select {
+ case p := <-c.linkEP.C:
+ t.Fatalf("unexpected packet received: %+v", p)
+ case <-time.After(1 * time.Second):
+ return
+ }
+ }
+
+ select {
+ case p := <-c.linkEP.C:
+ var pkt []byte
+ pkt = append(pkt, p.Header...)
+ pkt = append(pkt, p.Payload...)
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("packet wasn't written out")
+ }
+ })
+ }
+}
+
+// TestIncrementMalformedPacketsReceived verifies if the malformed received
+// global and endpoint stats get incremented.
+func TestIncrementMalformedPacketsReceived(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ payload := newPayload()
+ c.t.Helper()
+ h := unicastV6.header4Tuple(incoming)
+ c.injectV6Packet(payload, &h, false /* !valid */)
+
+ var want uint64 = 1
+ if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ }
+}
+
+// TestShutdownRead verifies endpoint read shutdown and error
+// stats increment on packet receive.
+func TestShutdownRead(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %v", err)
+ }
+
+ testFailingRead(c, unicastV6, true /* expectReadError */)
+
+ var want uint64 = 1
+ if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
+ }
+}
+
+// TestShutdownWrite verifies endpoint write shutdown and error
+// stats increment on packet write.
+func TestShutdownWrite(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %v", err)
+ }
+
+ testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
+}
+
+func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+ got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ switch err {
+ case nil:
+ want.PacketsSent.IncrementBy(incr)
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ want.WriteErrors.InvalidArgs.IncrementBy(incr)
+ case tcpip.ErrClosedForSend:
+ want.WriteErrors.WriteClosed.IncrementBy(incr)
+ case tcpip.ErrInvalidEndpointState:
+ want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
+ case tcpip.ErrNoLinkAddress:
+ want.SendErrors.NoLinkAddr.IncrementBy(incr)
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ want.SendErrors.NoRoute.IncrementBy(incr)
+ default:
+ want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
+ }
+ if got != want {
+ c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
+ }
+}
+
+func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+ got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ switch err {
+ case nil, tcpip.ErrWouldBlock:
+ case tcpip.ErrClosedForReceive:
+ want.ReadErrors.ReadClosed.IncrementBy(incr)
+ default:
+ c.t.Errorf("Endpoint error missing stats update err %v", err)
+ }
+ if got != want {
+ c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
+ }
+}