summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go35
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go21
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go3
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go4
6 files changed, 23 insertions, 46 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 74fe19e98..d1e4a7cb7 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -504,7 +504,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: r.LocalAddress,
@@ -519,11 +518,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, err = e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
e.ID = id
- e.route = r.Clone()
+ e.route = r
e.RegisterNICID = nicID
e.state = stateConnected
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 5df703fb2..7befcfc9b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -261,15 +261,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
e.mu.RLock()
+ defer e.mu.RUnlock()
if e.closed {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
payloadBytes, err := p.FullPayload()
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
@@ -278,7 +277,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
dstAddr := ip.DestinationAddress()
@@ -300,39 +298,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrDestinationRequired
}
- if e.route.IsResolutionRequired() {
- savedRoute := e.route
- // Promote lock to exclusive if using a shared route,
- // given that it may need to change in finishWrite.
- e.mu.RUnlock()
- e.mu.Lock()
-
- // Make sure that the route didn't change during the
- // time we didn't hold the lock.
- if !e.connected || savedRoute != e.route {
- e.mu.Unlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
- n, ch, err := e.finishWrite(payloadBytes, savedRoute)
- e.mu.Unlock()
- return n, ch, err
- }
-
- n, ch, err := e.finishWrite(payloadBytes, e.route)
- e.mu.RUnlock()
- return n, ch, err
+ return e.finishWrite(payloadBytes, e.route)
}
// The caller provided a destination. Reject destination address if it
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrNoRoute
}
@@ -340,13 +315,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
n, ch, err := e.finishWrite(payloadBytes, route)
route.Release()
- e.mu.RUnlock()
return n, ch, err
}
@@ -435,11 +408,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer route.Release()
if e.associated {
// Re-register the endpoint with the appropriate NIC.
if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ route.Release()
return err
}
e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
@@ -447,7 +420,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Save the route we've connected via.
- e.route = route.Clone()
+ e.route = route
e.connected = true
return nil
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 3e1041cbe..2d96a65bd 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -778,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
for {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index c944dccc0..0dc710276 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -462,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error {
func (h *handshake) resolveRoute() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resolutionWaker := &sleep.Waker{}
s.AddWaker(resolutionWaker, wakerForResolution)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -470,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error {
// Initial action is to resolve route.
index := wakerForResolution
+ attemptedResolution := false
for {
switch index {
case wakerForResolution:
- if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
- if err == tcpip.ErrNoLinkAddress {
- h.ep.stats.SendErrors.NoLinkAddr.Increment()
- } else if err != nil {
+ if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock {
+ if err != nil {
h.ep.stats.SendErrors.NoRoute.Increment()
}
// Either success (err == nil) or failure.
return err
}
+ if attemptedResolution {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ return tcpip.ErrNoLinkAddress
+ }
+ attemptedResolution = true
// Resolution not completed. Keep trying...
case wakerForNotification:
n := h.ep.fetchNotifications()
if n&notifyClose != 0 {
- h.ep.route.RemoveWaker(resolutionWaker)
return tcpip.ErrAborted
}
if n&notifyDrain != 0 {
@@ -563,7 +566,7 @@ func (h *handshake) start() *tcpip.Error {
// complete completes the TCP 3-way handshake initiated by h.start().
func (h *handshake) complete() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resendWaker := sleep.Waker{}
s.AddWaker(&resendWaker, wakerForResend)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -1512,7 +1515,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
// Initialize the sleeper based on the wakers in funcs.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
for i := range funcs {
s.AddWaker(funcs[i].w, i)
}
@@ -1699,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
const notification = 2
const timeWaitDone = 3
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
defer s.Done()
s.AddWaker(&e.newSegmentWaker, newSegment)
s.AddWaker(&e.notificationWaker, notification)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 2128206d7..c88e74bec 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2292,7 +2292,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
e.isRegistered = true
e.setEndpointState(StateConnecting)
- e.route = r.Clone()
+ r.Acquire()
+ e.route = r
e.boundNICID = nicID
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index d919fa011..24d0c2cb9 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1017,7 +1017,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: e.ID.LocalAddress,
@@ -1045,6 +1044,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
@@ -1055,7 +1055,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.ID = id
e.boundBindToDevice = btd
- e.route = r.Clone()
+ e.route = r
e.dstPort = addr.Port
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos