diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint_state.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 17 |
8 files changed, 47 insertions, 29 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 59ec54ca0..0a714498d 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -72,7 +72,7 @@ type endpoint struct { // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags state endpointState - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` // linger is used for SO_LINGER socket option. @@ -132,7 +132,10 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } // Update the state. e.state = stateClosed @@ -272,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c var route *stack.Route if to == nil { - route = &e.route + route = e.route if route.IsResolutionRequired() { // Promote lock to exclusive if using a shared route, @@ -313,7 +316,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } defer r.Release() - route = &r + route = r } if route.IsResolutionRequired() { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index b0b53b181..2b1022995 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -84,7 +84,7 @@ type endpoint struct { bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` // linger is used for SO_LINGER socket option. linger tcpip.LingerOption @@ -173,9 +173,11 @@ func (e *endpoint) Close() { e.rcvList.Remove(e.rcvList.Front()) } - if e.connected { + e.connected = false + + if e.route != nil { e.route.Release() - e.connected = false + e.route = nil } e.closed = true @@ -299,7 +301,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if e.route.IsResolutionRequired() { - savedRoute := &e.route + savedRoute := e.route // Promote lock to exclusive if using a shared route, // given that it may need to change in finishWrite. e.mu.RUnlock() @@ -307,7 +309,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Make sure that the route didn't change during the // time we didn't hold the lock. - if !e.connected || savedRoute != &e.route { + if !e.connected || savedRoute != e.route { e.mu.Unlock() return 0, nil, tcpip.ErrInvalidEndpointState } @@ -317,7 +319,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return n, ch, err } - n, ch, err := e.finishWrite(payloadBytes, &e.route) + n, ch, err := e.finishWrite(payloadBytes, e.route) e.mu.RUnlock() return n, ch, err } @@ -338,7 +340,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, err } - n, ch, err := e.finishWrite(payloadBytes, &route) + n, ch, err := e.finishWrite(payloadBytes, route) route.Release() e.mu.RUnlock() return n, ch, err diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 7d97cbdc7..4a7e1c039 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -73,7 +73,13 @@ func (e *endpoint) Resume(s *stack.Stack) { // If the endpoint is connected, re-connect. if e.connected { var err *tcpip.Error - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false) + // TODO(gvisor.dev/issue/4906): Properly restore the route with the right + // remote address. We used to pass e.remote.RemoteAddress which was + // effectively the empty address but since moving e.route to hold a pointer + // to a route instead of the route by value, we pass the empty address + // directly. Obviously this was always wrong since we should provide the + // remote address we were connected to, to properly restore the route. + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6e5adc383..5f2221f1b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -599,7 +599,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er ack: s.sequenceNumber + 1, rcvWnd: ctx.rcvWnd, } - if err := e.sendSynTCP(&route, fields, synOpts); err != nil { + if err := e.sendSynTCP(route, fields, synOpts); err != nil { return err } e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 88a632019..e38488d4d 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -300,7 +300,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { if ttl == 0 { ttl = h.ep.route.DefaultTTL() } - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: ttl, tos: h.ep.sendTOS, @@ -361,7 +361,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -547,7 +547,7 @@ func (h *handshake) start() *tcpip.Error { } h.sendSYNOpts = synOpts - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -596,7 +596,7 @@ func (h *handshake) complete() *tcpip.Error { // the connection with another ACK or data (as ACKs are never // retransmitted on their own). if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -939,7 +939,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := e.sendTCP(&e.route, tcpFields{ + err := e.sendTCP(e.route, tcpFields{ id: e.ID, ttl: e.ttl, tos: e.sendTOS, diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 64563a8ba..713a70b47 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -440,7 +440,7 @@ type endpoint struct { isPortReserved bool `state:"manual"` isRegistered bool `state:"manual"` boundNICID tcpip.NICID - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` ttl uint8 v6only bool isConnectNotified bool @@ -705,7 +705,7 @@ func (e *endpoint) UniqueID() uint64 { // // If userMSS is non-zero and is not greater than the maximum possible MSS for // r, it will be used; otherwise, the maximum possible MSS will be used. -func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { +func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 { // The maximum possible MSS is dependent on the route. // TODO(b/143359391): Respect TCP Min and Max size. maxMSS := uint16(r.MTU() - header.TCPMinimumSize) @@ -1173,7 +1173,11 @@ func (e *endpoint) cleanupLocked() { e.boundPortFlags = ports.Flags{} e.boundDest = tcpip.FullAddress{} - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2329aca4b..672159eed 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -250,7 +250,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error ttl = route.DefaultTTL() } - return sendTCP(&route, tcpFields{ + return sendTCP(route, tcpFields{ id: s.id, ttl: ttl, tos: tos, diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index a7a405dcb..9d33a694b 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -99,7 +99,7 @@ type endpoint struct { sndBufSize int sndBufSizeMax int state EndpointState - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` dstPort uint16 v6only bool ttl uint8 @@ -259,7 +259,10 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } // Update the state. e.state = StateClosed @@ -368,7 +371,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -387,7 +390,7 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr // Find a route to the desired destination. r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { - return stack.Route{}, 0, err + return nil, 0, err } return r, nicID, nil } @@ -459,7 +462,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) var dstPort uint16 if to == nil { - route = &e.route + route = e.route dstPort = e.dstPort resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) { // Promote lock to exclusive if using a shared route, given that it may @@ -512,7 +515,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } defer r.Release() - route = &r + route = r dstPort = dst.Port resolve = route.Resolve } @@ -1083,7 +1086,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.ID = id e.boundBindToDevice = btd e.route.Release() - e.route = stack.Route{} + e.route = nil e.dstPort = 0 return nil |