summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go7
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go44
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go21
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go97
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go12
-rw-r--r--pkg/tcpip/transport/tcp/snd.go35
-rw-r--r--pkg/tcpip/transport/tcp/tcp_state_autogen.go210
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go120
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_state_autogen.go39
12 files changed, 339 insertions, 254 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 74fe19e98..d1e4a7cb7 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -504,7 +504,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: r.LocalAddress,
@@ -519,11 +518,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, err = e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
e.ID = id
- e.route = r.Clone()
+ e.route = r
e.RegisterNICID = nicID
e.state = stateConnected
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 9faab4b9e..e5e247342 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -366,6 +366,13 @@ func (ep *endpoint) LastError() *tcpip.Error {
return err
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (ep *endpoint) UpdateLastError(err *tcpip.Error) {
+ ep.lastErrorMu.Lock()
+ ep.lastError = err
+ ep.lastErrorMu.Unlock()
+}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
return tcpip.ErrNotSupported
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index eee3f11c1..7befcfc9b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -261,15 +261,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
e.mu.RLock()
+ defer e.mu.RUnlock()
if e.closed {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
payloadBytes, err := p.FullPayload()
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
@@ -278,7 +277,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
dstAddr := ip.DestinationAddress()
@@ -300,39 +298,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrDestinationRequired
}
- if e.route.IsResolutionRequired() {
- savedRoute := e.route
- // Promote lock to exclusive if using a shared route,
- // given that it may need to change in finishWrite.
- e.mu.RUnlock()
- e.mu.Lock()
-
- // Make sure that the route didn't change during the
- // time we didn't hold the lock.
- if !e.connected || savedRoute != e.route {
- e.mu.Unlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
- n, ch, err := e.finishWrite(payloadBytes, savedRoute)
- e.mu.Unlock()
- return n, ch, err
- }
-
- n, ch, err := e.finishWrite(payloadBytes, e.route)
- e.mu.RUnlock()
- return n, ch, err
+ return e.finishWrite(payloadBytes, e.route)
}
// The caller provided a destination. Reject destination address if it
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrNoRoute
}
@@ -340,13 +315,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
n, ch, err := e.finishWrite(payloadBytes, route)
route.Release()
- e.mu.RUnlock()
return n, ch, err
}
@@ -404,7 +377,7 @@ func (*endpoint) Disconnect() *tcpip.Error {
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
- return tcpip.ErrInvalidOptionValue
+ return tcpip.ErrAddressFamilyNotSupported
}
e.mu.Lock()
@@ -435,11 +408,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer route.Release()
if e.associated {
// Re-register the endpoint with the appropriate NIC.
if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ route.Release()
return err
}
e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
@@ -447,7 +420,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Save the route we've connected via.
- e.route = route.Clone()
+ e.route = route
e.connected = true
return nil
@@ -620,6 +593,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ e.mu.RLock()
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full or if this is an unassociated
@@ -632,6 +606,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// sockets.
if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
return
@@ -639,6 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
return
@@ -650,11 +626,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// If bound to a NIC, only accept data for that NIC.
if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
// If bound to an address, only accept data for that address.
if e.BindAddr != "" && e.BindAddr != remoteAddr {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
}
@@ -663,6 +641,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// connected to.
if e.connected && e.route.RemoteAddress != remoteAddr {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
@@ -697,6 +676,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
if wasEmpty {
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 3e1041cbe..2d96a65bd 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -778,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
for {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index c944dccc0..0dc710276 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -462,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error {
func (h *handshake) resolveRoute() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resolutionWaker := &sleep.Waker{}
s.AddWaker(resolutionWaker, wakerForResolution)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -470,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error {
// Initial action is to resolve route.
index := wakerForResolution
+ attemptedResolution := false
for {
switch index {
case wakerForResolution:
- if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
- if err == tcpip.ErrNoLinkAddress {
- h.ep.stats.SendErrors.NoLinkAddr.Increment()
- } else if err != nil {
+ if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock {
+ if err != nil {
h.ep.stats.SendErrors.NoRoute.Increment()
}
// Either success (err == nil) or failure.
return err
}
+ if attemptedResolution {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ return tcpip.ErrNoLinkAddress
+ }
+ attemptedResolution = true
// Resolution not completed. Keep trying...
case wakerForNotification:
n := h.ep.fetchNotifications()
if n&notifyClose != 0 {
- h.ep.route.RemoveWaker(resolutionWaker)
return tcpip.ErrAborted
}
if n&notifyDrain != 0 {
@@ -563,7 +566,7 @@ func (h *handshake) start() *tcpip.Error {
// complete completes the TCP 3-way handshake initiated by h.start().
func (h *handshake) complete() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resendWaker := sleep.Waker{}
s.AddWaker(&resendWaker, wakerForResend)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -1512,7 +1515,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
// Initialize the sleeper based on the wakers in funcs.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
for i := range funcs {
s.AddWaker(funcs[i].w, i)
}
@@ -1699,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
const notification = 2
const timeWaitDone = 3
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
defer s.Done()
s.AddWaker(&e.newSegmentWaker, newSegment)
s.AddWaker(&e.notificationWaker, notification)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7a37c10bb..6e3c8860e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -502,9 +502,6 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
- // bindToDevice is set to the NIC on which to bind or disabled if 0.
- bindToDevice tcpip.NICID
-
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -1303,6 +1300,15 @@ func (e *endpoint) LastError() *tcpip.Error {
return e.lastErrorLocked()
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.LockUser()
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+ e.UnlockUser()
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1812,18 +1818,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.LockUser()
- e.bindToDevice = id
- e.UnlockUser()
-
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
e.keepalive.idle = time.Duration(*v)
@@ -2004,11 +2005,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
switch o := opt.(type) {
- case *tcpip.BindToDeviceOption:
- e.LockUser()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.UnlockUser()
-
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
e.LockUser()
@@ -2211,11 +2207,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
}
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
if err != tcpip.ErrPortInUse || !reuse {
return false, nil
}
@@ -2253,15 +2250,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
tcpEP.notifyProtocolGoroutine(notifyAbort)
tcpEP.UnlockUser()
// Now try and Reserve again if it fails then we skip.
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
return false, nil
}
}
id := e.ID
id.LocalPort = p
- if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr)
if err == tcpip.ErrPortInUse {
return false, nil
}
@@ -2272,7 +2269,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// the selected port.
e.ID = id
e.isPortReserved = true
- e.boundBindToDevice = e.bindToDevice
+ e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
e.boundDest = addr
return true, nil
@@ -2283,7 +2280,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
e.isRegistered = true
e.setEndpointState(StateConnecting)
- e.route = r.Clone()
+ r.Acquire()
+ e.route = r
e.boundNICID = nicID
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
@@ -2624,7 +2622,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
e.ID.LocalAddress = addr.Addr
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
id := e.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2635,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// demuxer. Further connected endpoints always have a remote
// address/port. Hence this will only return an error if there is a matching
// listening endpoint.
- if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil {
+ if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
return false
}
return true
@@ -2644,7 +2643,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
return err
}
- e.boundBindToDevice = e.bindToDevice
+ e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
// TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct.
e.boundNICID = nic
@@ -2708,6 +2707,41 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
return true
}
+func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ // Linux passes the payload with the TCP header. We don't know if the TCP
+ // header even exists, it may not for fragmented packets.
+ Payload: pkt.Data.ToView(),
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.RemoteAddress,
+ Port: id.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.notifyProtocolGoroutine(notifyError)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
switch typ {
@@ -2722,16 +2756,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.notifyProtocolGoroutine(notifyMTUChanged)
case stack.ControlNoRoute:
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrNoRoute
- e.lastErrorMu.Unlock()
- e.notifyProtocolGoroutine(notifyError)
+ e.onICMPError(tcpip.ErrNoRoute, id, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt)
case stack.ControlNetworkUnreachable:
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrNetworkUnreachable
- e.lastErrorMu.Unlock()
- e.notifyProtocolGoroutine(notifyError)
+ e.onICMPError(tcpip.ErrNetworkUnreachable, id, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt)
}
}
@@ -2989,6 +3017,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
Ssthresh: e.snd.sndSsthresh,
SndCAAckCount: e.snd.sndCAAckCount,
Outstanding: e.snd.outstanding,
+ SackedOut: e.snd.sackedOut,
SndWnd: e.snd.sndWnd,
SndUna: e.snd.sndUna,
SndNxt: e.snd.sndNxt,
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index f2b1b68da..405a6dce7 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -172,14 +172,12 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// If we started off with a window larger than what can he held in
// the 16bit window field, we ceil the value to the max value.
- // While ceiling, we still do not want to grow the right edge when
- // not applicable.
if scaledWnd > math.MaxUint16 {
- if toGrow {
- scaledWnd = seqnum.Size(math.MaxUint16)
- } else {
- scaledWnd = seqnum.Size(uint16(scaledWnd))
- }
+ scaledWnd = seqnum.Size(math.MaxUint16)
+
+ // Ensure that the stashed receive window always reflects what
+ // is being advertised.
+ r.rcvWnd = scaledWnd << r.rcvWndScale
}
return r.rcvNxt, scaledWnd
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index baec762e1..cc991aba6 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -137,6 +137,9 @@ type sender struct {
// that have been sent but not yet acknowledged.
outstanding int
+ // sackedOut is the number of packets which are selectively acked.
+ sackedOut int
+
// sndWnd is the send window size.
sndWnd seqnum.Size
@@ -372,6 +375,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m = 1
}
+ oldMSS := s.maxPayloadSize
s.maxPayloadSize = m
if s.gso {
s.ep.gso.MSS = uint16(m)
@@ -394,6 +398,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// Rewind writeNext to the first segment exceeding the MTU. Do nothing
// if it is already before such a packet.
+ nextSeg := s.writeNext
for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
if seg == s.writeNext {
// We got to writeNext before we could find a segment
@@ -401,16 +406,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
break
}
- if seg.data.Size() > m {
+ if nextSeg == s.writeNext && seg.data.Size() > m {
// We found a segment exceeding the MTU. Rewind
// writeNext and try to retransmit it.
- s.writeNext = seg
- break
+ nextSeg = seg
+ }
+
+ if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Update sackedOut for new maximum payload size.
+ s.sackedOut -= s.pCount(seg, oldMSS)
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
}
// Since we likely reduced the number of outstanding packets, we may be
// ready to send some more.
+ s.writeNext = nextSeg
s.sendData()
}
@@ -629,13 +640,13 @@ func (s *sender) retransmitTimerExpired() bool {
// pCount returns the number of packets in the segment. Due to GSO, a segment
// can be composed of multiple packets.
-func (s *sender) pCount(seg *segment) int {
+func (s *sender) pCount(seg *segment, maxPayloadSize int) int {
size := seg.data.Size()
if size == 0 {
return 1
}
- return (size-1)/s.maxPayloadSize + 1
+ return (size-1)/maxPayloadSize + 1
}
// splitSeg splits a given segment at the size specified and inserts the
@@ -1023,7 +1034,7 @@ func (s *sender) sendData() {
break
}
dataSent = true
- s.outstanding += s.pCount(seg)
+ s.outstanding += s.pCount(seg, s.maxPayloadSize)
s.writeNext = seg.Next()
}
@@ -1038,6 +1049,7 @@ func (s *sender) enterRecovery() {
// We inflate the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
s.sndCwnd = s.sndSsthresh + 3
+ s.sackedOut = 0
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
@@ -1207,6 +1219,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
s.rc.detectReorder(seg)
seg.acked = true
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
seg = seg.Next()
}
@@ -1380,10 +1393,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
datalen := seg.logicalLen()
if datalen > ackLeft {
- prevCount := s.pCount(seg)
+ prevCount := s.pCount(seg, s.maxPayloadSize)
seg.data.TrimFront(int(ackLeft))
seg.sequenceNumber.UpdateForward(ackLeft)
- s.outstanding -= prevCount - s.pCount(seg)
+ s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize)
break
}
@@ -1399,11 +1412,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.writeList.Remove(seg)
- // If SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- s.outstanding -= s.pCount(seg)
+ s.outstanding -= s.pCount(seg, s.maxPayloadSize)
+ } else {
+ s.sackedOut -= s.pCount(seg, s.maxPayloadSize)
}
seg.decRef()
ackLeft -= datalen
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index 8eba0efeb..5922083a9 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -187,7 +187,6 @@ func (e *endpoint) StateFields() []string {
"shutdownFlags",
"sackPermitted",
"sack",
- "bindToDevice",
"delay",
"scoreboard",
"segmentQueue",
@@ -232,7 +231,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
var recentTSTimeValue unixTime = e.saveRecentTSTime()
stateSinkObject.SaveValue(26, recentTSTimeValue)
var acceptedChanValue []*endpoint = e.saveAcceptedChan()
- stateSinkObject.SaveValue(50, acceptedChanValue)
+ stateSinkObject.SaveValue(49, acceptedChanValue)
stateSinkObject.Save(0, &e.EndpointInfo)
stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(2, &e.waiterQueue)
@@ -260,36 +259,35 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(28, &e.shutdownFlags)
stateSinkObject.Save(29, &e.sackPermitted)
stateSinkObject.Save(30, &e.sack)
- stateSinkObject.Save(31, &e.bindToDevice)
- stateSinkObject.Save(32, &e.delay)
- stateSinkObject.Save(33, &e.scoreboard)
- stateSinkObject.Save(34, &e.segmentQueue)
- stateSinkObject.Save(35, &e.synRcvdCount)
- stateSinkObject.Save(36, &e.userMSS)
- stateSinkObject.Save(37, &e.maxSynRetries)
- stateSinkObject.Save(38, &e.windowClamp)
- stateSinkObject.Save(39, &e.sndBufSize)
- stateSinkObject.Save(40, &e.sndBufUsed)
- stateSinkObject.Save(41, &e.sndClosed)
- stateSinkObject.Save(42, &e.sndBufInQueue)
- stateSinkObject.Save(43, &e.sndQueue)
- stateSinkObject.Save(44, &e.cc)
- stateSinkObject.Save(45, &e.packetTooBigCount)
- stateSinkObject.Save(46, &e.sndMTU)
- stateSinkObject.Save(47, &e.keepalive)
- stateSinkObject.Save(48, &e.userTimeout)
- stateSinkObject.Save(49, &e.deferAccept)
- stateSinkObject.Save(51, &e.rcv)
- stateSinkObject.Save(52, &e.snd)
- stateSinkObject.Save(53, &e.connectingAddress)
- stateSinkObject.Save(54, &e.amss)
- stateSinkObject.Save(55, &e.sendTOS)
- stateSinkObject.Save(56, &e.gso)
- stateSinkObject.Save(57, &e.tcpLingerTimeout)
- stateSinkObject.Save(58, &e.closed)
- stateSinkObject.Save(59, &e.txHash)
- stateSinkObject.Save(60, &e.owner)
- stateSinkObject.Save(61, &e.ops)
+ stateSinkObject.Save(31, &e.delay)
+ stateSinkObject.Save(32, &e.scoreboard)
+ stateSinkObject.Save(33, &e.segmentQueue)
+ stateSinkObject.Save(34, &e.synRcvdCount)
+ stateSinkObject.Save(35, &e.userMSS)
+ stateSinkObject.Save(36, &e.maxSynRetries)
+ stateSinkObject.Save(37, &e.windowClamp)
+ stateSinkObject.Save(38, &e.sndBufSize)
+ stateSinkObject.Save(39, &e.sndBufUsed)
+ stateSinkObject.Save(40, &e.sndClosed)
+ stateSinkObject.Save(41, &e.sndBufInQueue)
+ stateSinkObject.Save(42, &e.sndQueue)
+ stateSinkObject.Save(43, &e.cc)
+ stateSinkObject.Save(44, &e.packetTooBigCount)
+ stateSinkObject.Save(45, &e.sndMTU)
+ stateSinkObject.Save(46, &e.keepalive)
+ stateSinkObject.Save(47, &e.userTimeout)
+ stateSinkObject.Save(48, &e.deferAccept)
+ stateSinkObject.Save(50, &e.rcv)
+ stateSinkObject.Save(51, &e.snd)
+ stateSinkObject.Save(52, &e.connectingAddress)
+ stateSinkObject.Save(53, &e.amss)
+ stateSinkObject.Save(54, &e.sendTOS)
+ stateSinkObject.Save(55, &e.gso)
+ stateSinkObject.Save(56, &e.tcpLingerTimeout)
+ stateSinkObject.Save(57, &e.closed)
+ stateSinkObject.Save(58, &e.txHash)
+ stateSinkObject.Save(59, &e.owner)
+ stateSinkObject.Save(60, &e.ops)
}
func (e *endpoint) StateLoad(stateSourceObject state.Source) {
@@ -320,41 +318,40 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(28, &e.shutdownFlags)
stateSourceObject.Load(29, &e.sackPermitted)
stateSourceObject.Load(30, &e.sack)
- stateSourceObject.Load(31, &e.bindToDevice)
- stateSourceObject.Load(32, &e.delay)
- stateSourceObject.Load(33, &e.scoreboard)
- stateSourceObject.LoadWait(34, &e.segmentQueue)
- stateSourceObject.Load(35, &e.synRcvdCount)
- stateSourceObject.Load(36, &e.userMSS)
- stateSourceObject.Load(37, &e.maxSynRetries)
- stateSourceObject.Load(38, &e.windowClamp)
- stateSourceObject.Load(39, &e.sndBufSize)
- stateSourceObject.Load(40, &e.sndBufUsed)
- stateSourceObject.Load(41, &e.sndClosed)
- stateSourceObject.Load(42, &e.sndBufInQueue)
- stateSourceObject.LoadWait(43, &e.sndQueue)
- stateSourceObject.Load(44, &e.cc)
- stateSourceObject.Load(45, &e.packetTooBigCount)
- stateSourceObject.Load(46, &e.sndMTU)
- stateSourceObject.Load(47, &e.keepalive)
- stateSourceObject.Load(48, &e.userTimeout)
- stateSourceObject.Load(49, &e.deferAccept)
- stateSourceObject.LoadWait(51, &e.rcv)
- stateSourceObject.LoadWait(52, &e.snd)
- stateSourceObject.Load(53, &e.connectingAddress)
- stateSourceObject.Load(54, &e.amss)
- stateSourceObject.Load(55, &e.sendTOS)
- stateSourceObject.Load(56, &e.gso)
- stateSourceObject.Load(57, &e.tcpLingerTimeout)
- stateSourceObject.Load(58, &e.closed)
- stateSourceObject.Load(59, &e.txHash)
- stateSourceObject.Load(60, &e.owner)
- stateSourceObject.Load(61, &e.ops)
+ stateSourceObject.Load(31, &e.delay)
+ stateSourceObject.Load(32, &e.scoreboard)
+ stateSourceObject.LoadWait(33, &e.segmentQueue)
+ stateSourceObject.Load(34, &e.synRcvdCount)
+ stateSourceObject.Load(35, &e.userMSS)
+ stateSourceObject.Load(36, &e.maxSynRetries)
+ stateSourceObject.Load(37, &e.windowClamp)
+ stateSourceObject.Load(38, &e.sndBufSize)
+ stateSourceObject.Load(39, &e.sndBufUsed)
+ stateSourceObject.Load(40, &e.sndClosed)
+ stateSourceObject.Load(41, &e.sndBufInQueue)
+ stateSourceObject.LoadWait(42, &e.sndQueue)
+ stateSourceObject.Load(43, &e.cc)
+ stateSourceObject.Load(44, &e.packetTooBigCount)
+ stateSourceObject.Load(45, &e.sndMTU)
+ stateSourceObject.Load(46, &e.keepalive)
+ stateSourceObject.Load(47, &e.userTimeout)
+ stateSourceObject.Load(48, &e.deferAccept)
+ stateSourceObject.LoadWait(50, &e.rcv)
+ stateSourceObject.LoadWait(51, &e.snd)
+ stateSourceObject.Load(52, &e.connectingAddress)
+ stateSourceObject.Load(53, &e.amss)
+ stateSourceObject.Load(54, &e.sendTOS)
+ stateSourceObject.Load(55, &e.gso)
+ stateSourceObject.Load(56, &e.tcpLingerTimeout)
+ stateSourceObject.Load(57, &e.closed)
+ stateSourceObject.Load(58, &e.txHash)
+ stateSourceObject.Load(59, &e.owner)
+ stateSourceObject.Load(60, &e.ops)
stateSourceObject.LoadValue(4, new(string), func(y interface{}) { e.loadHardError(y.(string)) })
stateSourceObject.LoadValue(5, new(string), func(y interface{}) { e.loadLastError(y.(string)) })
stateSourceObject.LoadValue(13, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) })
stateSourceObject.LoadValue(26, new(unixTime), func(y interface{}) { e.loadRecentTSTime(y.(unixTime)) })
- stateSourceObject.LoadValue(50, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) })
+ stateSourceObject.LoadValue(49, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) })
stateSourceObject.AfterLoad(e.afterLoad)
}
@@ -724,6 +721,7 @@ func (s *sender) StateFields() []string {
"sndSsthresh",
"sndCAAckCount",
"outstanding",
+ "sackedOut",
"sndWnd",
"sndUna",
"sndNxt",
@@ -755,9 +753,9 @@ func (s *sender) StateSave(stateSinkObject state.Sink) {
var lastSendTimeValue unixTime = s.saveLastSendTime()
stateSinkObject.SaveValue(1, lastSendTimeValue)
var rttMeasureTimeValue unixTime = s.saveRttMeasureTime()
- stateSinkObject.SaveValue(13, rttMeasureTimeValue)
+ stateSinkObject.SaveValue(14, rttMeasureTimeValue)
var firstRetransmittedSegXmitTimeValue unixTime = s.saveFirstRetransmittedSegXmitTime()
- stateSinkObject.SaveValue(14, firstRetransmittedSegXmitTimeValue)
+ stateSinkObject.SaveValue(15, firstRetransmittedSegXmitTimeValue)
stateSinkObject.Save(0, &s.ep)
stateSinkObject.Save(2, &s.dupAckCount)
stateSinkObject.Save(3, &s.fr)
@@ -766,25 +764,26 @@ func (s *sender) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(6, &s.sndSsthresh)
stateSinkObject.Save(7, &s.sndCAAckCount)
stateSinkObject.Save(8, &s.outstanding)
- stateSinkObject.Save(9, &s.sndWnd)
- stateSinkObject.Save(10, &s.sndUna)
- stateSinkObject.Save(11, &s.sndNxt)
- stateSinkObject.Save(12, &s.rttMeasureSeqNum)
- stateSinkObject.Save(15, &s.closed)
- stateSinkObject.Save(16, &s.writeNext)
- stateSinkObject.Save(17, &s.writeList)
- stateSinkObject.Save(18, &s.rtt)
- stateSinkObject.Save(19, &s.rto)
- stateSinkObject.Save(20, &s.minRTO)
- stateSinkObject.Save(21, &s.maxRTO)
- stateSinkObject.Save(22, &s.maxRetries)
- stateSinkObject.Save(23, &s.maxPayloadSize)
- stateSinkObject.Save(24, &s.gso)
- stateSinkObject.Save(25, &s.sndWndScale)
- stateSinkObject.Save(26, &s.maxSentAck)
- stateSinkObject.Save(27, &s.state)
- stateSinkObject.Save(28, &s.cc)
- stateSinkObject.Save(29, &s.rc)
+ stateSinkObject.Save(9, &s.sackedOut)
+ stateSinkObject.Save(10, &s.sndWnd)
+ stateSinkObject.Save(11, &s.sndUna)
+ stateSinkObject.Save(12, &s.sndNxt)
+ stateSinkObject.Save(13, &s.rttMeasureSeqNum)
+ stateSinkObject.Save(16, &s.closed)
+ stateSinkObject.Save(17, &s.writeNext)
+ stateSinkObject.Save(18, &s.writeList)
+ stateSinkObject.Save(19, &s.rtt)
+ stateSinkObject.Save(20, &s.rto)
+ stateSinkObject.Save(21, &s.minRTO)
+ stateSinkObject.Save(22, &s.maxRTO)
+ stateSinkObject.Save(23, &s.maxRetries)
+ stateSinkObject.Save(24, &s.maxPayloadSize)
+ stateSinkObject.Save(25, &s.gso)
+ stateSinkObject.Save(26, &s.sndWndScale)
+ stateSinkObject.Save(27, &s.maxSentAck)
+ stateSinkObject.Save(28, &s.state)
+ stateSinkObject.Save(29, &s.cc)
+ stateSinkObject.Save(30, &s.rc)
}
func (s *sender) StateLoad(stateSourceObject state.Source) {
@@ -796,28 +795,29 @@ func (s *sender) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(6, &s.sndSsthresh)
stateSourceObject.Load(7, &s.sndCAAckCount)
stateSourceObject.Load(8, &s.outstanding)
- stateSourceObject.Load(9, &s.sndWnd)
- stateSourceObject.Load(10, &s.sndUna)
- stateSourceObject.Load(11, &s.sndNxt)
- stateSourceObject.Load(12, &s.rttMeasureSeqNum)
- stateSourceObject.Load(15, &s.closed)
- stateSourceObject.Load(16, &s.writeNext)
- stateSourceObject.Load(17, &s.writeList)
- stateSourceObject.Load(18, &s.rtt)
- stateSourceObject.Load(19, &s.rto)
- stateSourceObject.Load(20, &s.minRTO)
- stateSourceObject.Load(21, &s.maxRTO)
- stateSourceObject.Load(22, &s.maxRetries)
- stateSourceObject.Load(23, &s.maxPayloadSize)
- stateSourceObject.Load(24, &s.gso)
- stateSourceObject.Load(25, &s.sndWndScale)
- stateSourceObject.Load(26, &s.maxSentAck)
- stateSourceObject.Load(27, &s.state)
- stateSourceObject.Load(28, &s.cc)
- stateSourceObject.Load(29, &s.rc)
+ stateSourceObject.Load(9, &s.sackedOut)
+ stateSourceObject.Load(10, &s.sndWnd)
+ stateSourceObject.Load(11, &s.sndUna)
+ stateSourceObject.Load(12, &s.sndNxt)
+ stateSourceObject.Load(13, &s.rttMeasureSeqNum)
+ stateSourceObject.Load(16, &s.closed)
+ stateSourceObject.Load(17, &s.writeNext)
+ stateSourceObject.Load(18, &s.writeList)
+ stateSourceObject.Load(19, &s.rtt)
+ stateSourceObject.Load(20, &s.rto)
+ stateSourceObject.Load(21, &s.minRTO)
+ stateSourceObject.Load(22, &s.maxRTO)
+ stateSourceObject.Load(23, &s.maxRetries)
+ stateSourceObject.Load(24, &s.maxPayloadSize)
+ stateSourceObject.Load(25, &s.gso)
+ stateSourceObject.Load(26, &s.sndWndScale)
+ stateSourceObject.Load(27, &s.maxSentAck)
+ stateSourceObject.Load(28, &s.state)
+ stateSourceObject.Load(29, &s.cc)
+ stateSourceObject.Load(30, &s.rc)
stateSourceObject.LoadValue(1, new(unixTime), func(y interface{}) { s.loadLastSendTime(y.(unixTime)) })
- stateSourceObject.LoadValue(13, new(unixTime), func(y interface{}) { s.loadRttMeasureTime(y.(unixTime)) })
- stateSourceObject.LoadValue(14, new(unixTime), func(y interface{}) { s.loadFirstRetransmittedSegXmitTime(y.(unixTime)) })
+ stateSourceObject.LoadValue(14, new(unixTime), func(y interface{}) { s.loadRttMeasureTime(y.(unixTime)) })
+ stateSourceObject.LoadValue(15, new(unixTime), func(y interface{}) { s.loadFirstRetransmittedSegXmitTime(y.(unixTime)) })
stateSourceObject.AfterLoad(s.afterLoad)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 763d1d654..9b9e4deb0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -109,7 +109,6 @@ type endpoint struct {
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
portFlags ports.Flags
- bindToDevice tcpip.NICID
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -226,6 +225,13 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+}
+
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
@@ -511,6 +517,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
+ so := e.SocketOptions()
+ if so.GetRecvError() {
+ so.QueueLocalErr(
+ tcpip.ErrMessageTooLong,
+ route.NetProto,
+ header.UDPMaximumPacketSize,
+ tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ Port: dstPort,
+ },
+ v,
+ )
+ }
return 0, nil, tcpip.ErrMessageTooLong
}
@@ -638,6 +658,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
@@ -754,15 +778,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
delete(e.multicastMemberships, memToRemove)
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.mu.Lock()
- e.bindToDevice = id
- e.mu.Unlock()
-
case *tcpip.SocketDetachFilterOption:
return nil
}
@@ -838,11 +853,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
}
e.mu.Unlock()
- case *tcpip.BindToDeviceOption:
- e.mu.RLock()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.mu.RUnlock()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -996,7 +1006,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: e.ID.LocalAddress,
@@ -1024,6 +1033,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
@@ -1034,7 +1044,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.ID = id
e.boundBindToDevice = btd
- e.route = r.Clone()
+ e.route = r
e.dstPort = addr.Port
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
@@ -1092,21 +1102,22 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp
}
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
if err != nil {
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
id.LocalPort = port
}
e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
@@ -1259,6 +1270,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ // Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
@@ -1267,10 +1279,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap
- // packets at "Parse" instead of when handling a packet.
- pkt.Data.CapLength(int(hdr.PayloadLength()))
-
if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
@@ -1304,7 +1312,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: hdr.SourcePort(),
+ Port: header.UDP(hdr).SourcePort(),
},
destinationAddress: tcpip.FullAddress{
NIC: pkt.NICID,
@@ -1341,15 +1349,63 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
}
+func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ // Linux passes the payload without the UDP header.
+ var payload []byte
+ udp := header.UDP(pkt.Data.ToView())
+ if len(udp) >= header.UDPMinimumSize {
+ payload = udp.Payload()
+ }
+
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ Payload: payload,
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.RemoteAddress,
+ Port: id.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.waiterQueue.Notify(waiter.EventErr)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
if e.EndpointState() == StateConnected {
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrConnectionRefused
- e.lastErrorMu.Unlock()
-
- e.waiterQueue.Notify(waiter.EventErr)
+ var errType byte
+ var errCode byte
+ switch pkt.NetworkProtocolNumber {
+ case header.IPv4ProtocolNumber:
+ errType = byte(header.ICMPv4DstUnreachable)
+ errCode = byte(header.ICMPv4PortUnreachable)
+ case header.IPv6ProtocolNumber:
+ errType = byte(header.ICMPv6DstUnreachable)
+ errCode = byte(header.ICMPv6PortUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber))
+ }
+ e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt)
return
}
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 14e4648cd..d7fc21f11 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
route.ResolveWith(r.pkt.SourceLinkAddress())
ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
- if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
route.Release()
return nil, err
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
index ec0a8c902..2b7726097 100644
--- a/pkg/tcpip/transport/udp/udp_state_autogen.go
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -73,7 +73,6 @@ func (e *endpoint) StateFields() []string {
"multicastAddr",
"multicastNICID",
"portFlags",
- "bindToDevice",
"lastError",
"boundBindToDevice",
"boundPortFlags",
@@ -91,7 +90,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax()
stateSinkObject.SaveValue(6, rcvBufSizeMaxValue)
var lastErrorValue string = e.saveLastError()
- stateSinkObject.SaveValue(19, lastErrorValue)
+ stateSinkObject.SaveValue(18, lastErrorValue)
stateSinkObject.Save(0, &e.TransportEndpointInfo)
stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(2, &e.waiterQueue)
@@ -109,15 +108,14 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(15, &e.multicastAddr)
stateSinkObject.Save(16, &e.multicastNICID)
stateSinkObject.Save(17, &e.portFlags)
- stateSinkObject.Save(18, &e.bindToDevice)
- stateSinkObject.Save(20, &e.boundBindToDevice)
- stateSinkObject.Save(21, &e.boundPortFlags)
- stateSinkObject.Save(22, &e.sendTOS)
- stateSinkObject.Save(23, &e.shutdownFlags)
- stateSinkObject.Save(24, &e.multicastMemberships)
- stateSinkObject.Save(25, &e.effectiveNetProtos)
- stateSinkObject.Save(26, &e.owner)
- stateSinkObject.Save(27, &e.ops)
+ stateSinkObject.Save(19, &e.boundBindToDevice)
+ stateSinkObject.Save(20, &e.boundPortFlags)
+ stateSinkObject.Save(21, &e.sendTOS)
+ stateSinkObject.Save(22, &e.shutdownFlags)
+ stateSinkObject.Save(23, &e.multicastMemberships)
+ stateSinkObject.Save(24, &e.effectiveNetProtos)
+ stateSinkObject.Save(25, &e.owner)
+ stateSinkObject.Save(26, &e.ops)
}
func (e *endpoint) StateLoad(stateSourceObject state.Source) {
@@ -138,17 +136,16 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(15, &e.multicastAddr)
stateSourceObject.Load(16, &e.multicastNICID)
stateSourceObject.Load(17, &e.portFlags)
- stateSourceObject.Load(18, &e.bindToDevice)
- stateSourceObject.Load(20, &e.boundBindToDevice)
- stateSourceObject.Load(21, &e.boundPortFlags)
- stateSourceObject.Load(22, &e.sendTOS)
- stateSourceObject.Load(23, &e.shutdownFlags)
- stateSourceObject.Load(24, &e.multicastMemberships)
- stateSourceObject.Load(25, &e.effectiveNetProtos)
- stateSourceObject.Load(26, &e.owner)
- stateSourceObject.Load(27, &e.ops)
+ stateSourceObject.Load(19, &e.boundBindToDevice)
+ stateSourceObject.Load(20, &e.boundPortFlags)
+ stateSourceObject.Load(21, &e.sendTOS)
+ stateSourceObject.Load(22, &e.shutdownFlags)
+ stateSourceObject.Load(23, &e.multicastMemberships)
+ stateSourceObject.Load(24, &e.effectiveNetProtos)
+ stateSourceObject.Load(25, &e.owner)
+ stateSourceObject.Load(26, &e.ops)
stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) })
- stateSourceObject.LoadValue(19, new(string), func(y interface{}) { e.loadLastError(y.(string)) })
+ stateSourceObject.LoadValue(18, new(string), func(y interface{}) { e.loadLastError(y.(string)) })
stateSourceObject.AfterLoad(e.afterLoad)
}