summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-11-25 14:51:18 -0800
committergVisor bot <gvisor-bot@google.com>2020-11-25 14:52:59 -0800
commit2485a4e2cb4aaee8f1a5e760541fb02e9090de44 (patch)
treebf20ce486235c1bb9c9261a64f24823c822dafd5 /pkg/tcpip/transport
parent4d59a5a62223b56927b37a00cd5a6dea577fe4c6 (diff)
Make stack.Route safe to access concurrently
Multiple goroutines may use the same stack.Route concurrently so the stack.Route should make sure that any functions called on it are thread-safe. Fixes #4073 PiperOrigin-RevId: 344320491
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go11
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go16
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go10
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go10
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go17
8 files changed, 47 insertions, 29 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 59ec54ca0..0a714498d 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -72,7 +72,7 @@ type endpoint struct {
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
state endpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
// linger is used for SO_LINGER socket option.
@@ -132,7 +132,10 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
e.state = stateClosed
@@ -272,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
var route *stack.Route
if to == nil {
- route = &e.route
+ route = e.route
if route.IsResolutionRequired() {
// Promote lock to exclusive if using a shared route,
@@ -313,7 +316,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
}
if route.IsResolutionRequired() {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b0b53b181..2b1022995 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -84,7 +84,7 @@ type endpoint struct {
bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
// linger is used for SO_LINGER socket option.
linger tcpip.LingerOption
@@ -173,9 +173,11 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- if e.connected {
+ e.connected = false
+
+ if e.route != nil {
e.route.Release()
- e.connected = false
+ e.route = nil
}
e.closed = true
@@ -299,7 +301,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if e.route.IsResolutionRequired() {
- savedRoute := &e.route
+ savedRoute := e.route
// Promote lock to exclusive if using a shared route,
// given that it may need to change in finishWrite.
e.mu.RUnlock()
@@ -307,7 +309,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// Make sure that the route didn't change during the
// time we didn't hold the lock.
- if !e.connected || savedRoute != &e.route {
+ if !e.connected || savedRoute != e.route {
e.mu.Unlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
@@ -317,7 +319,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return n, ch, err
}
- n, ch, err := e.finishWrite(payloadBytes, &e.route)
+ n, ch, err := e.finishWrite(payloadBytes, e.route)
e.mu.RUnlock()
return n, ch, err
}
@@ -338,7 +340,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, err
}
- n, ch, err := e.finishWrite(payloadBytes, &route)
+ n, ch, err := e.finishWrite(payloadBytes, route)
route.Release()
e.mu.RUnlock()
return n, ch, err
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 7d97cbdc7..4a7e1c039 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -73,7 +73,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
// If the endpoint is connected, re-connect.
if e.connected {
var err *tcpip.Error
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false)
+ // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
+ // remote address. We used to pass e.remote.RemoteAddress which was
+ // effectively the empty address but since moving e.route to hold a pointer
+ // to a route instead of the route by value, we pass the empty address
+ // directly. Obviously this was always wrong since we should provide the
+ // remote address we were connected to, to properly restore the route.
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6e5adc383..5f2221f1b 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -599,7 +599,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
}
- if err := e.sendSynTCP(&route, fields, synOpts); err != nil {
+ if err := e.sendSynTCP(route, fields, synOpts); err != nil {
return err
}
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 88a632019..e38488d4d 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -300,7 +300,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
if ttl == 0 {
ttl = h.ep.route.DefaultTTL()
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: ttl,
tos: h.ep.sendTOS,
@@ -361,7 +361,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -547,7 +547,7 @@ func (h *handshake) start() *tcpip.Error {
}
h.sendSYNOpts = synOpts
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -596,7 +596,7 @@ func (h *handshake) complete() *tcpip.Error {
// the connection with another ACK or data (as ACKs are never
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -939,7 +939,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := e.sendTCP(&e.route, tcpFields{
+ err := e.sendTCP(e.route, tcpFields{
id: e.ID,
ttl: e.ttl,
tos: e.sendTOS,
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 64563a8ba..713a70b47 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -440,7 +440,7 @@ type endpoint struct {
isPortReserved bool `state:"manual"`
isRegistered bool `state:"manual"`
boundNICID tcpip.NICID
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
v6only bool
isConnectNotified bool
@@ -705,7 +705,7 @@ func (e *endpoint) UniqueID() uint64 {
//
// If userMSS is non-zero and is not greater than the maximum possible MSS for
// r, it will be used; otherwise, the maximum possible MSS will be used.
-func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 {
// The maximum possible MSS is dependent on the route.
// TODO(b/143359391): Respect TCP Min and Max size.
maxMSS := uint16(r.MTU() - header.TCPMinimumSize)
@@ -1173,7 +1173,11 @@ func (e *endpoint) cleanupLocked() {
e.boundPortFlags = ports.Flags{}
e.boundDest = tcpip.FullAddress{}
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
+
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2329aca4b..672159eed 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -250,7 +250,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error
ttl = route.DefaultTTL()
}
- return sendTCP(&route, tcpFields{
+ return sendTCP(route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index a7a405dcb..9d33a694b 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -99,7 +99,7 @@ type endpoint struct {
sndBufSize int
sndBufSizeMax int
state EndpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
dstPort uint16
v6only bool
ttl uint8
@@ -259,7 +259,10 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
e.state = StateClosed
@@ -368,7 +371,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -387,7 +390,7 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr
// Find a route to the desired destination.
r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
if err != nil {
- return stack.Route{}, 0, err
+ return nil, 0, err
}
return r, nicID, nil
}
@@ -459,7 +462,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
var dstPort uint16
if to == nil {
- route = &e.route
+ route = e.route
dstPort = e.dstPort
resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
// Promote lock to exclusive if using a shared route, given that it may
@@ -512,7 +515,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
dstPort = dst.Port
resolve = route.Resolve
}
@@ -1083,7 +1086,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
- e.route = stack.Route{}
+ e.route = nil
e.dstPort = 0
return nil