summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go11
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go16
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go10
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go10
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go17
8 files changed, 47 insertions, 29 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 59ec54ca0..0a714498d 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -72,7 +72,7 @@ type endpoint struct {
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
state endpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
// linger is used for SO_LINGER socket option.
@@ -132,7 +132,10 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
e.state = stateClosed
@@ -272,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
var route *stack.Route
if to == nil {
- route = &e.route
+ route = e.route
if route.IsResolutionRequired() {
// Promote lock to exclusive if using a shared route,
@@ -313,7 +316,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
}
if route.IsResolutionRequired() {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b0b53b181..2b1022995 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -84,7 +84,7 @@ type endpoint struct {
bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
// linger is used for SO_LINGER socket option.
linger tcpip.LingerOption
@@ -173,9 +173,11 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- if e.connected {
+ e.connected = false
+
+ if e.route != nil {
e.route.Release()
- e.connected = false
+ e.route = nil
}
e.closed = true
@@ -299,7 +301,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if e.route.IsResolutionRequired() {
- savedRoute := &e.route
+ savedRoute := e.route
// Promote lock to exclusive if using a shared route,
// given that it may need to change in finishWrite.
e.mu.RUnlock()
@@ -307,7 +309,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// Make sure that the route didn't change during the
// time we didn't hold the lock.
- if !e.connected || savedRoute != &e.route {
+ if !e.connected || savedRoute != e.route {
e.mu.Unlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
@@ -317,7 +319,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return n, ch, err
}
- n, ch, err := e.finishWrite(payloadBytes, &e.route)
+ n, ch, err := e.finishWrite(payloadBytes, e.route)
e.mu.RUnlock()
return n, ch, err
}
@@ -338,7 +340,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, err
}
- n, ch, err := e.finishWrite(payloadBytes, &route)
+ n, ch, err := e.finishWrite(payloadBytes, route)
route.Release()
e.mu.RUnlock()
return n, ch, err
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 7d97cbdc7..4a7e1c039 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -73,7 +73,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
// If the endpoint is connected, re-connect.
if e.connected {
var err *tcpip.Error
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false)
+ // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
+ // remote address. We used to pass e.remote.RemoteAddress which was
+ // effectively the empty address but since moving e.route to hold a pointer
+ // to a route instead of the route by value, we pass the empty address
+ // directly. Obviously this was always wrong since we should provide the
+ // remote address we were connected to, to properly restore the route.
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6e5adc383..5f2221f1b 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -599,7 +599,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
}
- if err := e.sendSynTCP(&route, fields, synOpts); err != nil {
+ if err := e.sendSynTCP(route, fields, synOpts); err != nil {
return err
}
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 88a632019..e38488d4d 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -300,7 +300,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
if ttl == 0 {
ttl = h.ep.route.DefaultTTL()
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: ttl,
tos: h.ep.sendTOS,
@@ -361,7 +361,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -547,7 +547,7 @@ func (h *handshake) start() *tcpip.Error {
}
h.sendSYNOpts = synOpts
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -596,7 +596,7 @@ func (h *handshake) complete() *tcpip.Error {
// the connection with another ACK or data (as ACKs are never
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -939,7 +939,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := e.sendTCP(&e.route, tcpFields{
+ err := e.sendTCP(e.route, tcpFields{
id: e.ID,
ttl: e.ttl,
tos: e.sendTOS,
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 64563a8ba..713a70b47 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -440,7 +440,7 @@ type endpoint struct {
isPortReserved bool `state:"manual"`
isRegistered bool `state:"manual"`
boundNICID tcpip.NICID
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
v6only bool
isConnectNotified bool
@@ -705,7 +705,7 @@ func (e *endpoint) UniqueID() uint64 {
//
// If userMSS is non-zero and is not greater than the maximum possible MSS for
// r, it will be used; otherwise, the maximum possible MSS will be used.
-func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 {
// The maximum possible MSS is dependent on the route.
// TODO(b/143359391): Respect TCP Min and Max size.
maxMSS := uint16(r.MTU() - header.TCPMinimumSize)
@@ -1173,7 +1173,11 @@ func (e *endpoint) cleanupLocked() {
e.boundPortFlags = ports.Flags{}
e.boundDest = tcpip.FullAddress{}
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
+
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2329aca4b..672159eed 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -250,7 +250,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error
ttl = route.DefaultTTL()
}
- return sendTCP(&route, tcpFields{
+ return sendTCP(route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index a7a405dcb..9d33a694b 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -99,7 +99,7 @@ type endpoint struct {
sndBufSize int
sndBufSizeMax int
state EndpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
dstPort uint16
v6only bool
ttl uint8
@@ -259,7 +259,10 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
e.state = StateClosed
@@ -368,7 +371,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -387,7 +390,7 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr
// Find a route to the desired destination.
r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
if err != nil {
- return stack.Route{}, 0, err
+ return nil, 0, err
}
return r, nicID, nil
}
@@ -459,7 +462,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
var dstPort uint16
if to == nil {
- route = &e.route
+ route = e.route
dstPort = e.dstPort
resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
// Promote lock to exclusive if using a shared route, given that it may
@@ -512,7 +515,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
dstPort = dst.Port
resolve = route.Resolve
}
@@ -1083,7 +1086,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
- e.route = stack.Route{}
+ e.route = nil
e.dstPort = 0
return nil