diff options
author | gVisor bot <gvisor-bot@google.com> | 2021-09-16 06:55:28 +0000 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-09-16 06:55:28 +0000 |
commit | 812a722c54a8c359f77d1bdfea9f63b810e5a5a8 (patch) | |
tree | 1a28a3763630ed49889424a70c89af65c437fdea /pkg/tcpip | |
parent | 9df83e598e95cb0251112d1fca1c4916e74e2f6a (diff) | |
parent | 477d7e5e10378e2f80f21ac9f536d12c4b94d7ce (diff) |
Merge release-20210906.0-30-g477d7e5e1 (automated)
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint.go | 114 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint_state.go | 4 |
2 files changed, 66 insertions, 52 deletions
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index 09b629022..3cb821475 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -18,7 +18,6 @@ package network import ( "fmt" - "sync/atomic" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,31 +37,41 @@ type Endpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - // state holds a transport.DatagramBasedEndpointState. - // - // state must be read from/written to atomically. - state uint32 - - // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + state transport.DatagramEndpointState + // +checklocks:mu wasBound bool - info stack.TransportEndpointInfo + // +checklocks:mu + info stack.TransportEndpointInfo // owner is the owner of transmitted packets. - owner tcpip.PacketOwner - writeShutdown bool - effectiveNetProto tcpip.NetworkProtocolNumber - connectedRoute *stack.Route `state:"manual"` + // + // +checklocks:mu + owner tcpip.PacketOwner + // +checklocks:mu + writeShutdown bool + // +checklocks:mu + effectiveNetProto tcpip.NetworkProtocolNumber + // +checklocks:mu + connectedRoute *stack.Route `state:"manual"` + // +checklocks:mu multicastMemberships map[multicastMembership]struct{} // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu ttl uint8 // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastTTL uint8 // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastAddr tcpip.Address // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastNICID tcpip.NICID - ipv4TOS uint8 - ipv6TClass uint8 + // +checklocks:mu + ipv4TOS uint8 + // +checklocks:mu + ipv6TClass uint8 } // +stateify savable @@ -73,8 +82,11 @@ type multicastMembership struct { // Init initializes the endpoint. func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) { - if e.multicastMemberships != nil { - panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships)) + e.mu.Lock() + memberships := e.multicastMemberships + e.mu.Unlock() + if memberships != nil { + panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships)) } switch netProto { @@ -89,8 +101,7 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr netProto: netProto, transProto: transProto, - state: uint32(transport.DatagramEndpointStateInitial), - + state: transport.DatagramEndpointStateInitial, info: stack.TransportEndpointInfo{ NetProto: netProto, TransProto: transProto, @@ -107,14 +118,11 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber { return e.netProto } -// setState sets the state of the endpoint. -func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) { - atomic.StoreUint32(&e.state, uint32(state)) -} - // State returns the state of the endpoint. func (e *Endpoint) State() transport.DatagramEndpointState { - return transport.DatagramEndpointState(atomic.LoadUint32(&e.state)) + e.mu.RLock() + defer e.mu.RUnlock() + return e.state } // Close cleans the endpoint's resources and leaves the endpoint in a closed @@ -123,7 +131,7 @@ func (e *Endpoint) Close() { e.mu.Lock() defer e.mu.Unlock() - if e.State() == transport.DatagramEndpointStateClosed { + if e.state == transport.DatagramEndpointStateClosed { return } @@ -137,7 +145,7 @@ func (e *Endpoint) Close() { e.connectedRoute = nil } - e.setEndpointState(transport.DatagramEndpointStateClosed) + e.state = transport.DatagramEndpointStateClosed } // SetOwner sets the owner of transmitted packets. @@ -218,7 +226,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext return WriteContext{}, &tcpip.ErrInvalidOptionValue{} } - if e.State() == transport.DatagramEndpointStateClosed { + if e.state == transport.DatagramEndpointStateClosed { return WriteContext{}, &tcpip.ErrInvalidEndpointState{} } @@ -230,7 +238,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext if opts.To == nil { // If the user doesn't specify a destination, they should have // connected to another address. - if e.State() != transport.DatagramEndpointStateConnected { + if e.state != transport.DatagramEndpointStateConnected { return WriteContext{}, &tcpip.ErrDestinationRequired{} } @@ -253,12 +261,12 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext nicID = e.info.RegisterNICID } - dst, netProto, err := e.checkV4MappedLocked(*opts.To) + dst, netProto, err := e.checkV4MappedRLocked(*opts.To) if err != nil { return WriteContext{}, err } - route, _, err = e.connectRoute(nicID, dst, netProto) + route, _, err = e.connectRouteRLocked(nicID, dst, netProto) if err != nil { return WriteContext{}, err } @@ -293,7 +301,7 @@ func (e *Endpoint) Disconnect() { e.mu.Lock() defer e.mu.Unlock() - if e.State() != transport.DatagramEndpointStateConnected { + if e.state != transport.DatagramEndpointStateConnected { return } @@ -302,20 +310,23 @@ func (e *Endpoint) Disconnect() { e.info.ID = stack.TransportEndpointID{ LocalAddress: e.info.BindAddr, } - e.setEndpointState(transport.DatagramEndpointStateBound) + e.state = transport.DatagramEndpointStateBound } else { e.info.ID = stack.TransportEndpointID{} - e.setEndpointState(transport.DatagramEndpointStateInitial) + e.state = transport.DatagramEndpointStateInitial } e.connectedRoute.Release() e.connectedRoute = nil } -// connectRoute establishes a route to the specified interface or the +// connectRouteRLocked establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { +// +// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. +// +checklocks:e.mu +func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { localAddr := e.info.ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -360,7 +371,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. defer e.mu.Unlock() nicID := addr.NIC - switch e.State() { + switch e.state { case transport.DatagramEndpointStateInitial: case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: if e.info.BindNICID == 0 { @@ -376,12 +387,12 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) + addr, netProto, err := e.checkV4MappedRLocked(addr) if err != nil { return err } - r, nicID, err := e.connectRoute(nicID, addr, netProto) + r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto) if err != nil { return err } @@ -390,7 +401,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. LocalAddress: e.info.ID.LocalAddress, RemoteAddress: r.RemoteAddress(), } - if e.State() == transport.DatagramEndpointStateInitial { + if e.state == transport.DatagramEndpointStateInitial { id.LocalAddress = r.LocalAddress() } @@ -406,7 +417,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. e.info.ID = id e.info.RegisterNICID = nicID e.effectiveNetProto = netProto - e.setEndpointState(transport.DatagramEndpointStateConnected) + e.state = transport.DatagramEndpointStateConnected return nil } @@ -415,7 +426,7 @@ func (e *Endpoint) Shutdown() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - switch state := e.State(); state { + switch state := e.state; state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: @@ -426,9 +437,12 @@ func (e *Endpoint) Shutdown() tcpip.Error { } } -// checkV4MappedLocked determines the effective network protocol and converts +// checkV4MappedRLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { +// +// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. +// +checklocks:e.mu +func (e *Endpoint) checkV4MappedRLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -460,11 +474,11 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto // Don't allow binding once endpoint is not in the initial state // anymore. - if e.State() != transport.DatagramEndpointStateInitial { + if e.state != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) + addr, netProto, err := e.checkV4MappedRLocked(addr) if err != nil { return err } @@ -490,7 +504,7 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto e.info.RegisterNICID = nicID e.info.BindAddr = addr.Addr e.effectiveNetProto = netProto - e.setEndpointState(transport.DatagramEndpointStateBound) + e.state = transport.DatagramEndpointStateBound return nil } @@ -507,7 +521,7 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { defer e.mu.RUnlock() addr := e.info.BindAddr - if e.State() == transport.DatagramEndpointStateConnected { + if e.state == transport.DatagramEndpointStateConnected { addr = e.connectedRoute.LocalAddress() } @@ -522,7 +536,7 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { e.mu.RLock() defer e.mu.RUnlock() - if e.State() != transport.DatagramEndpointStateConnected { + if e.state != transport.DatagramEndpointStateConnected { return tcpip.FullAddress{}, false } @@ -610,7 +624,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - fa, netProto, err := e.checkV4MappedLocked(fa) + fa, netProto, err := e.checkV4MappedRLocked(fa) if err != nil { return err } diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go index 858007156..173197512 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_state.go +++ b/pkg/tcpip/transport/internal/network/endpoint_state.go @@ -35,7 +35,7 @@ func (e *Endpoint) Resume(s *stack.Stack) { } } - switch state := e.State(); state { + switch e.state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: case transport.DatagramEndpointStateBound: if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) { @@ -51,6 +51,6 @@ func (e *Endpoint) Resume(s *stack.Stack) { panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) } default: - panic(fmt.Sprintf("unhandled state = %s", state)) + panic(fmt.Sprintf("unhandled state = %s", e.state)) } } |