diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 35 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 4 |
6 files changed, 23 insertions, 46 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 74fe19e98..d1e4a7cb7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -504,7 +504,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: r.LocalAddress, @@ -519,11 +518,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } e.ID = id - e.route = r.Clone() + e.route = r e.RegisterNICID = nicID e.state = stateConnected diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 5df703fb2..7befcfc9b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -261,15 +261,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } e.mu.RLock() + defer e.mu.RUnlock() if e.closed { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { - e.mu.RUnlock() return 0, nil, err } @@ -278,7 +277,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -300,39 +298,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if e.route.IsResolutionRequired() { - savedRoute := e.route - // Promote lock to exclusive if using a shared route, - // given that it may need to change in finishWrite. - e.mu.RUnlock() - e.mu.Lock() - - // Make sure that the route didn't change during the - // time we didn't hold the lock. - if !e.connected || savedRoute != e.route { - e.mu.Unlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - - n, ch, err := e.finishWrite(payloadBytes, savedRoute) - e.mu.Unlock() - return n, ch, err - } - - n, ch, err := e.finishWrite(payloadBytes, e.route) - e.mu.RUnlock() - return n, ch, err + return e.finishWrite(payloadBytes, e.route) } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } @@ -340,13 +315,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - e.mu.RUnlock() return 0, nil, err } n, ch, err := e.finishWrite(payloadBytes, route) route.Release() - e.mu.RUnlock() return n, ch, err } @@ -435,11 +408,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer route.Release() if e.associated { // Re-register the endpoint with the appropriate NIC. if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + route.Release() return err } e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) @@ -447,7 +420,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Save the route we've connected via. - e.route = route.Clone() + e.route = route e.connected = true return nil diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3e1041cbe..2d96a65bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -778,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() - s := sleep.Sleeper{} + var s sleep.Sleeper s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index c944dccc0..0dc710276 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -462,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error { func (h *handshake) resolveRoute() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resolutionWaker := &sleep.Waker{} s.AddWaker(resolutionWaker, wakerForResolution) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -470,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error { // Initial action is to resolve route. index := wakerForResolution + attemptedResolution := false for { switch index { case wakerForResolution: - if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { - if err == tcpip.ErrNoLinkAddress { - h.ep.stats.SendErrors.NoLinkAddr.Increment() - } else if err != nil { + if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock { + if err != nil { h.ep.stats.SendErrors.NoRoute.Increment() } // Either success (err == nil) or failure. return err } + if attemptedResolution { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + return tcpip.ErrNoLinkAddress + } + attemptedResolution = true // Resolution not completed. Keep trying... case wakerForNotification: n := h.ep.fetchNotifications() if n¬ifyClose != 0 { - h.ep.route.RemoveWaker(resolutionWaker) return tcpip.ErrAborted } if n¬ifyDrain != 0 { @@ -563,7 +566,7 @@ func (h *handshake) start() *tcpip.Error { // complete completes the TCP 3-way handshake initiated by h.start(). func (h *handshake) complete() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resendWaker := sleep.Waker{} s.AddWaker(&resendWaker, wakerForResend) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -1512,7 +1515,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } // Initialize the sleeper based on the wakers in funcs. - s := sleep.Sleeper{} + var s sleep.Sleeper for i := range funcs { s.AddWaker(funcs[i].w, i) } @@ -1699,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { const notification = 2 const timeWaitDone = 3 - s := sleep.Sleeper{} + var s sleep.Sleeper defer s.Done() s.AddWaker(&e.newSegmentWaker, newSegment) s.AddWaker(&e.notificationWaker, notification) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 2128206d7..c88e74bec 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2292,7 +2292,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.isRegistered = true e.setEndpointState(StateConnecting) - e.route = r.Clone() + r.Acquire() + e.route = r e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index d919fa011..24d0c2cb9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1017,7 +1017,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: e.ID.LocalAddress, @@ -1045,6 +1044,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } @@ -1055,7 +1055,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.ID = id e.boundBindToDevice = btd - e.route = r.Clone() + e.route = r e.dstPort = addr.Port e.RegisterNICID = nicID e.effectiveNetProtos = netProtos |