From 680398ab76c92c88674c5770015fd100b33e137a Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Tue, 12 Jan 2021 18:42:18 -0800 Subject: Remove unnecessary closure PiperOrigin-RevId: 351491836 --- pkg/tcpip/transport/tcp/endpoint.go | 64 ++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 36 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 281f4cd58..9275d5f3a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1544,46 +1544,38 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, perr } - queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { - // Add data to the send queue. - s := newOutgoingSegment(e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - - // Do the work inline. - e.handleWrite() - e.UnlockUser() - return int64(len(v)), nil, nil - } - - if opts.Atomic { - // Locks released in queueAndSend() - return queueAndSend() - } + if !opts.Atomic { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + e.LockUser() + e.sndBufMu.Lock() + avail, err := e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.UnlockUser() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - e.LockUser() - e.sndBufMu.Lock() - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.UnlockUser() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } + // Add data to the send queue. + s := newOutgoingSegment(e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() - // Locks released in queueAndSend() - return queueAndSend() + // Do the work inline. + e.handleWrite() + e.UnlockUser() + return int64(len(v)), nil, nil } // selectWindowLocked returns the new window without checking for shrinking or scaling -- cgit v1.2.3 From 62b4c2f5173dfa75387c079bd3dd6d5e5c3abae9 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 12 Jan 2021 19:34:43 -0800 Subject: Drop TransportEndpointID from HandleControlPacket When a control packet is delivered, it is delivered to a transport endpoint with a matching stack.TransportEndpointID so there is no need to pass the ID to the endpoint as it already knows its ID. PiperOrigin-RevId: 351497588 --- pkg/tcpip/stack/registration.go | 2 +- pkg/tcpip/stack/transport_demuxer.go | 5 +++-- pkg/tcpip/stack/transport_test.go | 2 +- pkg/tcpip/transport/icmp/endpoint.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 16 ++++++++-------- pkg/tcpip/transport/udp/endpoint.go | 14 +++++++------- 6 files changed, 21 insertions(+), 20 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 7e83b7fbb..4795208b4 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -84,7 +84,7 @@ type TransportEndpoint interface { // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) + HandleControlPacket(typ ControlType, extra uint32, pkt *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index f183ec6e4..07b2818d2 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -182,7 +182,8 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +// handleControlPacket delivers a control packet to the transport endpoint +// identified by id. func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -199,7 +200,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a5facf578..0ff32c6ea 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -237,7 +237,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * f.acceptQueue = append(f.acceptQueue, ep) } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.ControlType, uint32, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 2eb4457df..c32fe5c4f 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -789,7 +789,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9275d5f3a..25b180fa5 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2728,7 +2728,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } -func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -2747,13 +2747,13 @@ func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, e Payload: pkt.Data.ToView(), Dst: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: id.RemoteAddress, - Port: id.RemotePort, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, Offender: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: id.LocalAddress, - Port: id.LocalPort, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, }, NetProto: pkt.NetworkProtocolNumber, }) @@ -2764,7 +2764,7 @@ func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, e } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() @@ -2777,10 +2777,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.notifyProtocolGoroutine(notifyMTUChanged) case stack.ControlNoRoute: - e.onICMPError(tcpip.ErrNoRoute, id, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) + e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) case stack.ControlNetworkUnreachable: - e.onICMPError(tcpip.ErrNetworkUnreachable, id, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) + e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 075de1db0..5d87f3a7e 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1352,7 +1352,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } -func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -1376,13 +1376,13 @@ func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, e Payload: payload, Dst: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: id.RemoteAddress, - Port: id.RemotePort, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, Offender: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: id.LocalAddress, - Port: id.LocalPort, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, }, NetProto: pkt.NetworkProtocolNumber, }) @@ -1393,7 +1393,7 @@ func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, e } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { if e.EndpointState() == StateConnected { var errType byte @@ -1408,7 +1408,7 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C default: panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) } - e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt) + e.onICMPError(tcpip.ErrConnectionRefused, errType, errCode, extra, pkt) return } } -- cgit v1.2.3