diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 44 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 97 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 35 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 210 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 120 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/forwarder.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_state_autogen.go | 39 |
12 files changed, 339 insertions, 254 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 74fe19e98..d1e4a7cb7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -504,7 +504,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: r.LocalAddress, @@ -519,11 +518,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } e.ID = id - e.route = r.Clone() + e.route = r e.RegisterNICID = nicID e.state = stateConnected diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 9faab4b9e..e5e247342 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -366,6 +366,13 @@ func (ep *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (ep *endpoint) UpdateLastError(err *tcpip.Error) { + ep.lastErrorMu.Lock() + ep.lastError = err + ep.lastErrorMu.Unlock() +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { return tcpip.ErrNotSupported diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index eee3f11c1..7befcfc9b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -261,15 +261,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } e.mu.RLock() + defer e.mu.RUnlock() if e.closed { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { - e.mu.RUnlock() return 0, nil, err } @@ -278,7 +277,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -300,39 +298,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if e.route.IsResolutionRequired() { - savedRoute := e.route - // Promote lock to exclusive if using a shared route, - // given that it may need to change in finishWrite. - e.mu.RUnlock() - e.mu.Lock() - - // Make sure that the route didn't change during the - // time we didn't hold the lock. - if !e.connected || savedRoute != e.route { - e.mu.Unlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - - n, ch, err := e.finishWrite(payloadBytes, savedRoute) - e.mu.Unlock() - return n, ch, err - } - - n, ch, err := e.finishWrite(payloadBytes, e.route) - e.mu.RUnlock() - return n, ch, err + return e.finishWrite(payloadBytes, e.route) } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. 
nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } @@ -340,13 +315,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - e.mu.RUnlock() return 0, nil, err } n, ch, err := e.finishWrite(payloadBytes, route) route.Release() - e.mu.RUnlock() return n, ch, err } @@ -404,7 +377,7 @@ func (*endpoint) Disconnect() *tcpip.Error { func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { - return tcpip.ErrInvalidOptionValue + return tcpip.ErrAddressFamilyNotSupported } e.mu.Lock() @@ -435,11 +408,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer route.Release() if e.associated { // Re-register the endpoint with the appropriate NIC. if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + route.Release() return err } e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) @@ -447,7 +420,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Save the route we've connected via. - e.route = route.Clone() + e.route = route e.connected = true return nil @@ -620,6 +593,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + e.mu.RLock() e.rcvMu.Lock() // Drop the packet if our buffer is currently full or if this is an unassociated @@ -632,6 +606,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // sockets. 
if e.rcvClosed || !e.associated { e.rcvMu.Unlock() + e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() return @@ -639,6 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return @@ -650,11 +626,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // If bound to a NIC, only accept data for that NIC. if e.BindNICID != 0 && e.BindNICID != pkt.NICID { e.rcvMu.Unlock() + e.mu.RUnlock() return } // If bound to an address, only accept data for that address. if e.BindAddr != "" && e.BindAddr != remoteAddr { e.rcvMu.Unlock() + e.mu.RUnlock() return } } @@ -663,6 +641,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // connected to. if e.connected && e.route.RemoteAddress != remoteAddr { e.rcvMu.Unlock() + e.mu.RUnlock() return } @@ -697,6 +676,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() e.rcvMu.Unlock() + e.mu.RUnlock() e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. 
if wasEmpty { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3e1041cbe..2d96a65bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -778,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() - s := sleep.Sleeper{} + var s sleep.Sleeper s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index c944dccc0..0dc710276 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -462,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error { func (h *handshake) resolveRoute() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resolutionWaker := &sleep.Waker{} s.AddWaker(resolutionWaker, wakerForResolution) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -470,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error { // Initial action is to resolve route. index := wakerForResolution + attemptedResolution := false for { switch index { case wakerForResolution: - if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { - if err == tcpip.ErrNoLinkAddress { - h.ep.stats.SendErrors.NoLinkAddr.Increment() - } else if err != nil { + if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock { + if err != nil { h.ep.stats.SendErrors.NoRoute.Increment() } // Either success (err == nil) or failure. return err } + if attemptedResolution { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + return tcpip.ErrNoLinkAddress + } + attemptedResolution = true // Resolution not completed. Keep trying... 
case wakerForNotification: n := h.ep.fetchNotifications() if n&notifyClose != 0 { - h.ep.route.RemoveWaker(resolutionWaker) return tcpip.ErrAborted } if n&notifyDrain != 0 { @@ -563,7 +566,7 @@ func (h *handshake) start() *tcpip.Error { // complete completes the TCP 3-way handshake initiated by h.start(). func (h *handshake) complete() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resendWaker := sleep.Waker{} s.AddWaker(&resendWaker, wakerForResend) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -1512,7 +1515,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } // Initialize the sleeper based on the wakers in funcs. - s := sleep.Sleeper{} + var s sleep.Sleeper for i := range funcs { s.AddWaker(funcs[i].w, i) } @@ -1699,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { const notification = 2 const timeWaitDone = 3 - s := sleep.Sleeper{} + var s sleep.Sleeper defer s.Done() s.AddWaker(&e.newSegmentWaker, newSegment) s.AddWaker(&e.notificationWaker, notification) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 7a37c10bb..6e3c8860e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -502,9 +502,6 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // bindToDevice is set to the NIC on which to bind or disabled if 0. - bindToDevice tcpip.NICID - // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -1303,6 +1300,15 @@ func (e *endpoint) LastError() *tcpip.Error { return e.lastErrorLocked() } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.LockUser() + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + e.UnlockUser() +} + // Read reads data from the endpoint. 
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -1812,18 +1818,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.LockUser() - e.bindToDevice = id - e.UnlockUser() - case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(*v) @@ -2004,11 +2005,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch o := opt.(type) { - case *tcpip.BindToDeviceOption: - e.LockUser() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.UnlockUser() - case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} e.LockUser() @@ -2211,11 +2207,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } } + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { if err != tcpip.ErrPortInUse || !reuse { return false, nil } @@ -2253,15 +2250,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc 
tcpEP.notifyProtocolGoroutine(notifyAbort) tcpEP.UnlockUser() // Now try and Reserve again if it fails then we skip. - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { return false, nil } } id := e.ID id.LocalPort = p - if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) if err == tcpip.ErrPortInUse { return false, nil } @@ -2272,7 +2269,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // the selected port. 
e.ID = id e.isPortReserved = true - e.boundBindToDevice = e.bindToDevice + e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags e.boundDest = addr return true, nil @@ -2283,7 +2280,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.isRegistered = true e.setEndpointState(StateConnecting) - e.route = r.Clone() + r.Acquire() + e.route = r e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -2624,7 +2622,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { id := e.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2635,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // demuxer. Further connected endpoints always have a remote // address/port. Hence this will only return an error if there is a matching // listening endpoint. - if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil { + if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { return false } return true @@ -2644,7 +2643,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { return err } - e.boundBindToDevice = e.bindToDevice + e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct. 
e.boundNICID = nic @@ -2708,6 +2707,41 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } +func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + // Linux passes the payload with the TCP header. We don't know if the TCP + // header even exists, it may not for fragmented packets. + Payload: pkt.Data.ToView(), + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.RemoteAddress, + Port: id.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: id.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.notifyProtocolGoroutine(notifyError) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { @@ -2722,16 +2756,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.notifyProtocolGoroutine(notifyMTUChanged) case stack.ControlNoRoute: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNoRoute - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNoRoute, id, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) case stack.ControlNetworkUnreachable: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNetworkUnreachable - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNetworkUnreachable, id, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } } @@ -2989,6 +3017,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { Ssthresh: e.snd.sndSsthresh, SndCAAckCount: e.snd.sndCAAckCount, Outstanding: e.snd.outstanding, + SackedOut: e.snd.sackedOut, SndWnd: e.snd.sndWnd, SndUna: e.snd.sndUna, SndNxt: e.snd.sndNxt, diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index f2b1b68da..405a6dce7 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -172,14 +172,12 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // If we started off with a window larger than what can he held in // the 16bit window field, we ceil the value to the max value. - // While ceiling, we still do not want to grow the right edge when - // not applicable. if scaledWnd > math.MaxUint16 { - if toGrow { - scaledWnd = seqnum.Size(math.MaxUint16) - } else { - scaledWnd = seqnum.Size(uint16(scaledWnd)) - } + scaledWnd = seqnum.Size(math.MaxUint16) + + // Ensure that the stashed receive window always reflects what + // is being advertised. 
+ r.rcvWnd = scaledWnd << r.rcvWndScale } return r.rcvNxt, scaledWnd } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index baec762e1..cc991aba6 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -137,6 +137,9 @@ type sender struct { // that have been sent but not yet acknowledged. outstanding int + // sackedOut is the number of packets which are selectively acked. + sackedOut int + // sndWnd is the send window size. sndWnd seqnum.Size @@ -372,6 +375,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m = 1 } + oldMSS := s.maxPayloadSize s.maxPayloadSize = m if s.gso { s.ep.gso.MSS = uint16(m) @@ -394,6 +398,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // Rewind writeNext to the first segment exceeding the MTU. Do nothing // if it is already before such a packet. + nextSeg := s.writeNext for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { if seg == s.writeNext { // We got to writeNext before we could find a segment @@ -401,16 +406,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { break } - if seg.data.Size() > m { + if nextSeg == s.writeNext && seg.data.Size() > m { // We found a segment exceeding the MTU. Rewind // writeNext and try to retransmit it. - s.writeNext = seg - break + nextSeg = seg + } + + if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Update sackedOut for new maximum payload size. + s.sackedOut -= s.pCount(seg, oldMSS) + s.sackedOut += s.pCount(seg, s.maxPayloadSize) } } // Since we likely reduced the number of outstanding packets, we may be // ready to send some more. + s.writeNext = nextSeg s.sendData() } @@ -629,13 +640,13 @@ func (s *sender) retransmitTimerExpired() bool { // pCount returns the number of packets in the segment. Due to GSO, a segment // can be composed of multiple packets. 
-func (s *sender) pCount(seg *segment) int { +func (s *sender) pCount(seg *segment, maxPayloadSize int) int { size := seg.data.Size() if size == 0 { return 1 } - return (size-1)/s.maxPayloadSize + 1 + return (size-1)/maxPayloadSize + 1 } // splitSeg splits a given segment at the size specified and inserts the @@ -1023,7 +1034,7 @@ func (s *sender) sendData() { break } dataSent = true - s.outstanding += s.pCount(seg) + s.outstanding += s.pCount(seg, s.maxPayloadSize) s.writeNext = seg.Next() } @@ -1038,6 +1049,7 @@ func (s *sender) enterRecovery() { // We inflate the cwnd by 3 to account for the 3 packets which triggered // the 3 duplicate ACKs and are now not in flight. s.sndCwnd = s.sndSsthresh + 3 + s.sackedOut = 0 s.fr.first = s.sndUna s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding @@ -1207,6 +1219,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.update(seg, rcvdSeg, s.ep.tsOffset) s.rc.detectReorder(seg) seg.acked = true + s.sackedOut += s.pCount(seg, s.maxPayloadSize) } seg = seg.Next() } @@ -1380,10 +1393,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { datalen := seg.logicalLen() if datalen > ackLeft { - prevCount := s.pCount(seg) + prevCount := s.pCount(seg, s.maxPayloadSize) seg.data.TrimFront(int(ackLeft)) seg.sequenceNumber.UpdateForward(ackLeft) - s.outstanding -= prevCount - s.pCount(seg) + s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize) break } @@ -1399,11 +1412,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.writeList.Remove(seg) - // If SACK is enabled then Only reduce outstanding if + // If SACK is enabled then only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). 
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - s.outstanding -= s.pCount(seg) + s.outstanding -= s.pCount(seg, s.maxPayloadSize) + } else { + s.sackedOut -= s.pCount(seg, s.maxPayloadSize) } seg.decRef() ackLeft -= datalen diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go index 8eba0efeb..5922083a9 100644 --- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -187,7 +187,6 @@ func (e *endpoint) StateFields() []string { "shutdownFlags", "sackPermitted", "sack", - "bindToDevice", "delay", "scoreboard", "segmentQueue", @@ -232,7 +231,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { var recentTSTimeValue unixTime = e.saveRecentTSTime() stateSinkObject.SaveValue(26, recentTSTimeValue) var acceptedChanValue []*endpoint = e.saveAcceptedChan() - stateSinkObject.SaveValue(50, acceptedChanValue) + stateSinkObject.SaveValue(49, acceptedChanValue) stateSinkObject.Save(0, &e.EndpointInfo) stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) stateSinkObject.Save(2, &e.waiterQueue) @@ -260,36 +259,35 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(28, &e.shutdownFlags) stateSinkObject.Save(29, &e.sackPermitted) stateSinkObject.Save(30, &e.sack) - stateSinkObject.Save(31, &e.bindToDevice) - stateSinkObject.Save(32, &e.delay) - stateSinkObject.Save(33, &e.scoreboard) - stateSinkObject.Save(34, &e.segmentQueue) - stateSinkObject.Save(35, &e.synRcvdCount) - stateSinkObject.Save(36, &e.userMSS) - stateSinkObject.Save(37, &e.maxSynRetries) - stateSinkObject.Save(38, &e.windowClamp) - stateSinkObject.Save(39, &e.sndBufSize) - stateSinkObject.Save(40, &e.sndBufUsed) - stateSinkObject.Save(41, &e.sndClosed) - stateSinkObject.Save(42, &e.sndBufInQueue) - stateSinkObject.Save(43, &e.sndQueue) - stateSinkObject.Save(44, &e.cc) - stateSinkObject.Save(45, &e.packetTooBigCount) - stateSinkObject.Save(46, 
&e.sndMTU) - stateSinkObject.Save(47, &e.keepalive) - stateSinkObject.Save(48, &e.userTimeout) - stateSinkObject.Save(49, &e.deferAccept) - stateSinkObject.Save(51, &e.rcv) - stateSinkObject.Save(52, &e.snd) - stateSinkObject.Save(53, &e.connectingAddress) - stateSinkObject.Save(54, &e.amss) - stateSinkObject.Save(55, &e.sendTOS) - stateSinkObject.Save(56, &e.gso) - stateSinkObject.Save(57, &e.tcpLingerTimeout) - stateSinkObject.Save(58, &e.closed) - stateSinkObject.Save(59, &e.txHash) - stateSinkObject.Save(60, &e.owner) - stateSinkObject.Save(61, &e.ops) + stateSinkObject.Save(31, &e.delay) + stateSinkObject.Save(32, &e.scoreboard) + stateSinkObject.Save(33, &e.segmentQueue) + stateSinkObject.Save(34, &e.synRcvdCount) + stateSinkObject.Save(35, &e.userMSS) + stateSinkObject.Save(36, &e.maxSynRetries) + stateSinkObject.Save(37, &e.windowClamp) + stateSinkObject.Save(38, &e.sndBufSize) + stateSinkObject.Save(39, &e.sndBufUsed) + stateSinkObject.Save(40, &e.sndClosed) + stateSinkObject.Save(41, &e.sndBufInQueue) + stateSinkObject.Save(42, &e.sndQueue) + stateSinkObject.Save(43, &e.cc) + stateSinkObject.Save(44, &e.packetTooBigCount) + stateSinkObject.Save(45, &e.sndMTU) + stateSinkObject.Save(46, &e.keepalive) + stateSinkObject.Save(47, &e.userTimeout) + stateSinkObject.Save(48, &e.deferAccept) + stateSinkObject.Save(50, &e.rcv) + stateSinkObject.Save(51, &e.snd) + stateSinkObject.Save(52, &e.connectingAddress) + stateSinkObject.Save(53, &e.amss) + stateSinkObject.Save(54, &e.sendTOS) + stateSinkObject.Save(55, &e.gso) + stateSinkObject.Save(56, &e.tcpLingerTimeout) + stateSinkObject.Save(57, &e.closed) + stateSinkObject.Save(58, &e.txHash) + stateSinkObject.Save(59, &e.owner) + stateSinkObject.Save(60, &e.ops) } func (e *endpoint) StateLoad(stateSourceObject state.Source) { @@ -320,41 +318,40 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(28, &e.shutdownFlags) stateSourceObject.Load(29, &e.sackPermitted) 
stateSourceObject.Load(30, &e.sack) - stateSourceObject.Load(31, &e.bindToDevice) - stateSourceObject.Load(32, &e.delay) - stateSourceObject.Load(33, &e.scoreboard) - stateSourceObject.LoadWait(34, &e.segmentQueue) - stateSourceObject.Load(35, &e.synRcvdCount) - stateSourceObject.Load(36, &e.userMSS) - stateSourceObject.Load(37, &e.maxSynRetries) - stateSourceObject.Load(38, &e.windowClamp) - stateSourceObject.Load(39, &e.sndBufSize) - stateSourceObject.Load(40, &e.sndBufUsed) - stateSourceObject.Load(41, &e.sndClosed) - stateSourceObject.Load(42, &e.sndBufInQueue) - stateSourceObject.LoadWait(43, &e.sndQueue) - stateSourceObject.Load(44, &e.cc) - stateSourceObject.Load(45, &e.packetTooBigCount) - stateSourceObject.Load(46, &e.sndMTU) - stateSourceObject.Load(47, &e.keepalive) - stateSourceObject.Load(48, &e.userTimeout) - stateSourceObject.Load(49, &e.deferAccept) - stateSourceObject.LoadWait(51, &e.rcv) - stateSourceObject.LoadWait(52, &e.snd) - stateSourceObject.Load(53, &e.connectingAddress) - stateSourceObject.Load(54, &e.amss) - stateSourceObject.Load(55, &e.sendTOS) - stateSourceObject.Load(56, &e.gso) - stateSourceObject.Load(57, &e.tcpLingerTimeout) - stateSourceObject.Load(58, &e.closed) - stateSourceObject.Load(59, &e.txHash) - stateSourceObject.Load(60, &e.owner) - stateSourceObject.Load(61, &e.ops) + stateSourceObject.Load(31, &e.delay) + stateSourceObject.Load(32, &e.scoreboard) + stateSourceObject.LoadWait(33, &e.segmentQueue) + stateSourceObject.Load(34, &e.synRcvdCount) + stateSourceObject.Load(35, &e.userMSS) + stateSourceObject.Load(36, &e.maxSynRetries) + stateSourceObject.Load(37, &e.windowClamp) + stateSourceObject.Load(38, &e.sndBufSize) + stateSourceObject.Load(39, &e.sndBufUsed) + stateSourceObject.Load(40, &e.sndClosed) + stateSourceObject.Load(41, &e.sndBufInQueue) + stateSourceObject.LoadWait(42, &e.sndQueue) + stateSourceObject.Load(43, &e.cc) + stateSourceObject.Load(44, &e.packetTooBigCount) + stateSourceObject.Load(45, &e.sndMTU) + 
stateSourceObject.Load(46, &e.keepalive) + stateSourceObject.Load(47, &e.userTimeout) + stateSourceObject.Load(48, &e.deferAccept) + stateSourceObject.LoadWait(50, &e.rcv) + stateSourceObject.LoadWait(51, &e.snd) + stateSourceObject.Load(52, &e.connectingAddress) + stateSourceObject.Load(53, &e.amss) + stateSourceObject.Load(54, &e.sendTOS) + stateSourceObject.Load(55, &e.gso) + stateSourceObject.Load(56, &e.tcpLingerTimeout) + stateSourceObject.Load(57, &e.closed) + stateSourceObject.Load(58, &e.txHash) + stateSourceObject.Load(59, &e.owner) + stateSourceObject.Load(60, &e.ops) stateSourceObject.LoadValue(4, new(string), func(y interface{}) { e.loadHardError(y.(string)) }) stateSourceObject.LoadValue(5, new(string), func(y interface{}) { e.loadLastError(y.(string)) }) stateSourceObject.LoadValue(13, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) }) stateSourceObject.LoadValue(26, new(unixTime), func(y interface{}) { e.loadRecentTSTime(y.(unixTime)) }) - stateSourceObject.LoadValue(50, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) }) + stateSourceObject.LoadValue(49, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) }) stateSourceObject.AfterLoad(e.afterLoad) } @@ -724,6 +721,7 @@ func (s *sender) StateFields() []string { "sndSsthresh", "sndCAAckCount", "outstanding", + "sackedOut", "sndWnd", "sndUna", "sndNxt", @@ -755,9 +753,9 @@ func (s *sender) StateSave(stateSinkObject state.Sink) { var lastSendTimeValue unixTime = s.saveLastSendTime() stateSinkObject.SaveValue(1, lastSendTimeValue) var rttMeasureTimeValue unixTime = s.saveRttMeasureTime() - stateSinkObject.SaveValue(13, rttMeasureTimeValue) + stateSinkObject.SaveValue(14, rttMeasureTimeValue) var firstRetransmittedSegXmitTimeValue unixTime = s.saveFirstRetransmittedSegXmitTime() - stateSinkObject.SaveValue(14, firstRetransmittedSegXmitTimeValue) + stateSinkObject.SaveValue(15, firstRetransmittedSegXmitTimeValue) 
stateSinkObject.Save(0, &s.ep) stateSinkObject.Save(2, &s.dupAckCount) stateSinkObject.Save(3, &s.fr) @@ -766,25 +764,26 @@ func (s *sender) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(6, &s.sndSsthresh) stateSinkObject.Save(7, &s.sndCAAckCount) stateSinkObject.Save(8, &s.outstanding) - stateSinkObject.Save(9, &s.sndWnd) - stateSinkObject.Save(10, &s.sndUna) - stateSinkObject.Save(11, &s.sndNxt) - stateSinkObject.Save(12, &s.rttMeasureSeqNum) - stateSinkObject.Save(15, &s.closed) - stateSinkObject.Save(16, &s.writeNext) - stateSinkObject.Save(17, &s.writeList) - stateSinkObject.Save(18, &s.rtt) - stateSinkObject.Save(19, &s.rto) - stateSinkObject.Save(20, &s.minRTO) - stateSinkObject.Save(21, &s.maxRTO) - stateSinkObject.Save(22, &s.maxRetries) - stateSinkObject.Save(23, &s.maxPayloadSize) - stateSinkObject.Save(24, &s.gso) - stateSinkObject.Save(25, &s.sndWndScale) - stateSinkObject.Save(26, &s.maxSentAck) - stateSinkObject.Save(27, &s.state) - stateSinkObject.Save(28, &s.cc) - stateSinkObject.Save(29, &s.rc) + stateSinkObject.Save(9, &s.sackedOut) + stateSinkObject.Save(10, &s.sndWnd) + stateSinkObject.Save(11, &s.sndUna) + stateSinkObject.Save(12, &s.sndNxt) + stateSinkObject.Save(13, &s.rttMeasureSeqNum) + stateSinkObject.Save(16, &s.closed) + stateSinkObject.Save(17, &s.writeNext) + stateSinkObject.Save(18, &s.writeList) + stateSinkObject.Save(19, &s.rtt) + stateSinkObject.Save(20, &s.rto) + stateSinkObject.Save(21, &s.minRTO) + stateSinkObject.Save(22, &s.maxRTO) + stateSinkObject.Save(23, &s.maxRetries) + stateSinkObject.Save(24, &s.maxPayloadSize) + stateSinkObject.Save(25, &s.gso) + stateSinkObject.Save(26, &s.sndWndScale) + stateSinkObject.Save(27, &s.maxSentAck) + stateSinkObject.Save(28, &s.state) + stateSinkObject.Save(29, &s.cc) + stateSinkObject.Save(30, &s.rc) } func (s *sender) StateLoad(stateSourceObject state.Source) { @@ -796,28 +795,29 @@ func (s *sender) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(6, 
&s.sndSsthresh) stateSourceObject.Load(7, &s.sndCAAckCount) stateSourceObject.Load(8, &s.outstanding) - stateSourceObject.Load(9, &s.sndWnd) - stateSourceObject.Load(10, &s.sndUna) - stateSourceObject.Load(11, &s.sndNxt) - stateSourceObject.Load(12, &s.rttMeasureSeqNum) - stateSourceObject.Load(15, &s.closed) - stateSourceObject.Load(16, &s.writeNext) - stateSourceObject.Load(17, &s.writeList) - stateSourceObject.Load(18, &s.rtt) - stateSourceObject.Load(19, &s.rto) - stateSourceObject.Load(20, &s.minRTO) - stateSourceObject.Load(21, &s.maxRTO) - stateSourceObject.Load(22, &s.maxRetries) - stateSourceObject.Load(23, &s.maxPayloadSize) - stateSourceObject.Load(24, &s.gso) - stateSourceObject.Load(25, &s.sndWndScale) - stateSourceObject.Load(26, &s.maxSentAck) - stateSourceObject.Load(27, &s.state) - stateSourceObject.Load(28, &s.cc) - stateSourceObject.Load(29, &s.rc) + stateSourceObject.Load(9, &s.sackedOut) + stateSourceObject.Load(10, &s.sndWnd) + stateSourceObject.Load(11, &s.sndUna) + stateSourceObject.Load(12, &s.sndNxt) + stateSourceObject.Load(13, &s.rttMeasureSeqNum) + stateSourceObject.Load(16, &s.closed) + stateSourceObject.Load(17, &s.writeNext) + stateSourceObject.Load(18, &s.writeList) + stateSourceObject.Load(19, &s.rtt) + stateSourceObject.Load(20, &s.rto) + stateSourceObject.Load(21, &s.minRTO) + stateSourceObject.Load(22, &s.maxRTO) + stateSourceObject.Load(23, &s.maxRetries) + stateSourceObject.Load(24, &s.maxPayloadSize) + stateSourceObject.Load(25, &s.gso) + stateSourceObject.Load(26, &s.sndWndScale) + stateSourceObject.Load(27, &s.maxSentAck) + stateSourceObject.Load(28, &s.state) + stateSourceObject.Load(29, &s.cc) + stateSourceObject.Load(30, &s.rc) stateSourceObject.LoadValue(1, new(unixTime), func(y interface{}) { s.loadLastSendTime(y.(unixTime)) }) - stateSourceObject.LoadValue(13, new(unixTime), func(y interface{}) { s.loadRttMeasureTime(y.(unixTime)) }) - stateSourceObject.LoadValue(14, new(unixTime), func(y interface{}) { 
s.loadFirstRetransmittedSegXmitTime(y.(unixTime)) }) + stateSourceObject.LoadValue(14, new(unixTime), func(y interface{}) { s.loadRttMeasureTime(y.(unixTime)) }) + stateSourceObject.LoadValue(15, new(unixTime), func(y interface{}) { s.loadFirstRetransmittedSegXmitTime(y.(unixTime)) }) stateSourceObject.AfterLoad(s.afterLoad) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 763d1d654..9b9e4deb0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -109,7 +109,6 @@ type endpoint struct { multicastAddr tcpip.Address multicastNICID tcpip.NICID portFlags ports.Flags - bindToDevice tcpip.NICID lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -226,6 +225,13 @@ func (e *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() +} + // Abort implements stack.TransportEndpoint.Abort. func (e *endpoint) Abort() { e.Close() @@ -511,6 +517,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. + so := e.SocketOptions() + if so.GetRecvError() { + so.QueueLocalErr( + tcpip.ErrMessageTooLong, + route.NetProto, + header.UDPMaximumPacketSize, + tcpip.FullAddress{ + NIC: route.NICID(), + Addr: route.RemoteAddress, + Port: dstPort, + }, + v, + ) + } return 0, nil, tcpip.ErrMessageTooLong } @@ -638,6 +658,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt implements tcpip.Endpoint.SetSockOpt. 
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { @@ -754,15 +778,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { delete(e.multicastMemberships, memToRemove) - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.mu.Lock() - e.bindToDevice = id - e.mu.Unlock() - case *tcpip.SocketDetachFilterOption: return nil } @@ -838,11 +853,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { } e.mu.Unlock() - case *tcpip.BindToDeviceOption: - e.mu.RLock() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.mu.RUnlock() - default: return tcpip.ErrUnknownProtocolOption } @@ -996,7 +1006,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: e.ID.LocalAddress, @@ -1024,6 +1033,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } @@ -1034,7 +1044,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.ID = id e.boundBindToDevice = btd - e.route = r.Clone() + e.route = r e.dstPort = addr.Port e.RegisterNICID = nicID e.effectiveNetProtos = netProtos @@ -1092,21 +1102,22 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp } func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, 
id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) if err != nil { - return id, e.bindToDevice, err + return id, bindToDevice, err } id.LocalPort = port } e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } - return id, e.bindToDevice, err + return id, bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { @@ -1259,6 +1270,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + // Get the header then trim it from the view. hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. @@ -1267,10 +1279,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap - // packets at "Parse" instead of when handling a packet. - pkt.Data.CapLength(int(hdr.PayloadLength())) - if !verifyChecksum(hdr, pkt) { // Checksum Error. 
e.stack.Stats().UDP.ChecksumErrors.Increment() @@ -1304,7 +1312,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, destinationAddress: tcpip.FullAddress{ NIC: pkt.NICID, @@ -1341,15 +1349,63 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } +func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + // Linux passes the payload without the UDP header. + var payload []byte + udp := header.UDP(pkt.Data.ToView()) + if len(udp) >= header.UDPMinimumSize { + payload = udp.Payload() + } + + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + Payload: payload, + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.RemoteAddress, + Port: id.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: id.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.waiterQueue.Notify(waiter.EventErr) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { if e.EndpointState() == StateConnected { - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrConnectionRefused - e.lastErrorMu.Unlock() - - e.waiterQueue.Notify(waiter.EventErr) + var errType byte + var errCode byte + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + errType = byte(header.ICMPv4DstUnreachable) + errCode = byte(header.ICMPv4PortUnreachable) + case header.IPv6ProtocolNumber: + errType = byte(header.ICMPv6DstUnreachable) + errCode = byte(header.ICMPv6PortUnreachable) + default: + panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) + } + e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt) return } } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 14e4648cd..d7fc21f11 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, route.ResolveWith(r.pkt.SourceLinkAddress()) ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) - if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() route.Release() return nil, err diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go index ec0a8c902..2b7726097 100644 --- a/pkg/tcpip/transport/udp/udp_state_autogen.go +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -73,7 
+73,6 @@ func (e *endpoint) StateFields() []string { "multicastAddr", "multicastNICID", "portFlags", - "bindToDevice", "lastError", "boundBindToDevice", "boundPortFlags", @@ -91,7 +90,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax() stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) var lastErrorValue string = e.saveLastError() - stateSinkObject.SaveValue(19, lastErrorValue) + stateSinkObject.SaveValue(18, lastErrorValue) stateSinkObject.Save(0, &e.TransportEndpointInfo) stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) stateSinkObject.Save(2, &e.waiterQueue) @@ -109,15 +108,14 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(15, &e.multicastAddr) stateSinkObject.Save(16, &e.multicastNICID) stateSinkObject.Save(17, &e.portFlags) - stateSinkObject.Save(18, &e.bindToDevice) - stateSinkObject.Save(20, &e.boundBindToDevice) - stateSinkObject.Save(21, &e.boundPortFlags) - stateSinkObject.Save(22, &e.sendTOS) - stateSinkObject.Save(23, &e.shutdownFlags) - stateSinkObject.Save(24, &e.multicastMemberships) - stateSinkObject.Save(25, &e.effectiveNetProtos) - stateSinkObject.Save(26, &e.owner) - stateSinkObject.Save(27, &e.ops) + stateSinkObject.Save(19, &e.boundBindToDevice) + stateSinkObject.Save(20, &e.boundPortFlags) + stateSinkObject.Save(21, &e.sendTOS) + stateSinkObject.Save(22, &e.shutdownFlags) + stateSinkObject.Save(23, &e.multicastMemberships) + stateSinkObject.Save(24, &e.effectiveNetProtos) + stateSinkObject.Save(25, &e.owner) + stateSinkObject.Save(26, &e.ops) } func (e *endpoint) StateLoad(stateSourceObject state.Source) { @@ -138,17 +136,16 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(15, &e.multicastAddr) stateSourceObject.Load(16, &e.multicastNICID) stateSourceObject.Load(17, &e.portFlags) - stateSourceObject.Load(18, &e.bindToDevice) - stateSourceObject.Load(20, &e.boundBindToDevice) - 
stateSourceObject.Load(21, &e.boundPortFlags) - stateSourceObject.Load(22, &e.sendTOS) - stateSourceObject.Load(23, &e.shutdownFlags) - stateSourceObject.Load(24, &e.multicastMemberships) - stateSourceObject.Load(25, &e.effectiveNetProtos) - stateSourceObject.Load(26, &e.owner) - stateSourceObject.Load(27, &e.ops) + stateSourceObject.Load(19, &e.boundBindToDevice) + stateSourceObject.Load(20, &e.boundPortFlags) + stateSourceObject.Load(21, &e.sendTOS) + stateSourceObject.Load(22, &e.shutdownFlags) + stateSourceObject.Load(23, &e.multicastMemberships) + stateSourceObject.Load(24, &e.effectiveNetProtos) + stateSourceObject.Load(25, &e.owner) + stateSourceObject.Load(26, &e.ops) stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) - stateSourceObject.LoadValue(19, new(string), func(y interface{}) { e.loadLastError(y.(string)) }) + stateSourceObject.LoadValue(18, new(string), func(y interface{}) { e.loadLastError(y.(string)) }) stateSourceObject.AfterLoad(e.afterLoad) } |