From 477d7e5e10378e2f80f21ac9f536d12c4b94d7ce Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 15 Sep 2021 23:48:05 -0700 Subject: Annotate checklocks on mutex protected fields ...to catch lock-related bugs in nogo tests. Also update the endpoint's state field to be accessed while the mutex is held instead of requiring atomic operations as nothing needs to call the State method while the mutex is held. Updates #6566. PiperOrigin-RevId: 397010316 --- pkg/tcpip/transport/internal/network/endpoint.go | 114 ++++++++++++--------- .../transport/internal/network/endpoint_state.go | 4 +- 2 files changed, 66 insertions(+), 52 deletions(-) (limited to 'pkg') diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index 09b629022..3cb821475 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -18,7 +18,6 @@ package network import ( "fmt" - "sync/atomic" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,31 +37,41 @@ type Endpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - // state holds a transport.DatagramBasedEndpointState. - // - // state must be read from/written to atomically. - state uint32 - - // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + // +checklocks:mu + state transport.DatagramEndpointState + // +checklocks:mu wasBound bool - info stack.TransportEndpointInfo + // +checklocks:mu + info stack.TransportEndpointInfo // owner is the owner of transmitted packets. - owner tcpip.PacketOwner - writeShutdown bool - effectiveNetProto tcpip.NetworkProtocolNumber - connectedRoute *stack.Route `state:"manual"` + // + // +checklocks:mu + owner tcpip.PacketOwner + // +checklocks:mu + writeShutdown bool + // +checklocks:mu + effectiveNetProto tcpip.NetworkProtocolNumber + // +checklocks:mu + connectedRoute *stack.Route `state:"manual"` + // +checklocks:mu multicastMemberships map[multicastMembership]struct{} // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu ttl uint8 // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastTTL uint8 // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastAddr tcpip.Address // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + // +checklocks:mu multicastNICID tcpip.NICID - ipv4TOS uint8 - ipv6TClass uint8 + // +checklocks:mu + ipv4TOS uint8 + // +checklocks:mu + ipv6TClass uint8 } // +stateify savable @@ -73,8 +82,11 @@ type multicastMembership struct { // Init initializes the endpoint. func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) { - if e.multicastMemberships != nil { - panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships)) + e.mu.Lock() + memberships := e.multicastMemberships + e.mu.Unlock() + if memberships != nil { + panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships)) } switch netProto { @@ -89,8 +101,7 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr netProto: netProto, transProto: transProto, - state: uint32(transport.DatagramEndpointStateInitial), - + state: transport.DatagramEndpointStateInitial, info: stack.TransportEndpointInfo{ NetProto: netProto, TransProto: transProto, @@ -107,14 +118,11 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber { return e.netProto } -// setState sets the state of the endpoint. -func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) { - atomic.StoreUint32(&e.state, uint32(state)) -} - // State returns the state of the endpoint. func (e *Endpoint) State() transport.DatagramEndpointState { - return transport.DatagramEndpointState(atomic.LoadUint32(&e.state)) + e.mu.RLock() + defer e.mu.RUnlock() + return e.state } // Close cleans the endpoint's resources and leaves the endpoint in a closed @@ -123,7 +131,7 @@ func (e *Endpoint) Close() { e.mu.Lock() defer e.mu.Unlock() - if e.State() == transport.DatagramEndpointStateClosed { + if e.state == transport.DatagramEndpointStateClosed { return } @@ -137,7 +145,7 @@ func (e *Endpoint) Close() { e.connectedRoute = nil } - e.setEndpointState(transport.DatagramEndpointStateClosed) + e.state = transport.DatagramEndpointStateClosed } // SetOwner sets the owner of transmitted packets. @@ -218,7 +226,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext return WriteContext{}, &tcpip.ErrInvalidOptionValue{} } - if e.State() == transport.DatagramEndpointStateClosed { + if e.state == transport.DatagramEndpointStateClosed { return WriteContext{}, &tcpip.ErrInvalidEndpointState{} } @@ -230,7 +238,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext if opts.To == nil { // If the user doesn't specify a destination, they should have // connected to another address. - if e.State() != transport.DatagramEndpointStateConnected { + if e.state != transport.DatagramEndpointStateConnected { return WriteContext{}, &tcpip.ErrDestinationRequired{} } @@ -253,12 +261,12 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext nicID = e.info.RegisterNICID } - dst, netProto, err := e.checkV4MappedLocked(*opts.To) + dst, netProto, err := e.checkV4MappedRLocked(*opts.To) if err != nil { return WriteContext{}, err } - route, _, err = e.connectRoute(nicID, dst, netProto) + route, _, err = e.connectRouteRLocked(nicID, dst, netProto) if err != nil { return WriteContext{}, err } @@ -293,7 +301,7 @@ func (e *Endpoint) Disconnect() { e.mu.Lock() defer e.mu.Unlock() - if e.State() != transport.DatagramEndpointStateConnected { + if e.state != transport.DatagramEndpointStateConnected { return } @@ -302,20 +310,23 @@ func (e *Endpoint) Disconnect() { e.info.ID = stack.TransportEndpointID{ LocalAddress: e.info.BindAddr, } - e.setEndpointState(transport.DatagramEndpointStateBound) + e.state = transport.DatagramEndpointStateBound } else { e.info.ID = stack.TransportEndpointID{} - e.setEndpointState(transport.DatagramEndpointStateInitial) + e.state = transport.DatagramEndpointStateInitial } e.connectedRoute.Release() e.connectedRoute = nil } -// connectRoute establishes a route to the specified interface or the +// connectRouteRLocked establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { +// +// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. +// +checklocks:e.mu +func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { localAddr := e.info.ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -360,7 +371,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. defer e.mu.Unlock() nicID := addr.NIC - switch e.State() { + switch e.state { case transport.DatagramEndpointStateInitial: case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: if e.info.BindNICID == 0 { @@ -376,12 +387,12 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) + addr, netProto, err := e.checkV4MappedRLocked(addr) if err != nil { return err } - r, nicID, err := e.connectRoute(nicID, addr, netProto) + r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto) if err != nil { return err } @@ -390,7 +401,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. LocalAddress: e.info.ID.LocalAddress, RemoteAddress: r.RemoteAddress(), } - if e.State() == transport.DatagramEndpointStateInitial { + if e.state == transport.DatagramEndpointStateInitial { id.LocalAddress = r.LocalAddress() } @@ -406,7 +417,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. e.info.ID = id e.info.RegisterNICID = nicID e.effectiveNetProto = netProto - e.setEndpointState(transport.DatagramEndpointStateConnected) + e.state = transport.DatagramEndpointStateConnected return nil } @@ -415,7 +426,7 @@ func (e *Endpoint) Shutdown() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - switch state := e.State(); state { + switch state := e.state; state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: @@ -426,9 +437,12 @@ func (e *Endpoint) Shutdown() tcpip.Error { } } -// checkV4MappedLocked determines the effective network protocol and converts +// checkV4MappedRLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { +// +// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. +// +checklocks:e.mu +func (e *Endpoint) checkV4MappedRLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -460,11 +474,11 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto // Don't allow binding once endpoint is not in the initial state // anymore. - if e.State() != transport.DatagramEndpointStateInitial { + if e.state != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) + addr, netProto, err := e.checkV4MappedRLocked(addr) if err != nil { return err } @@ -490,7 +504,7 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto e.info.RegisterNICID = nicID e.info.BindAddr = addr.Addr e.effectiveNetProto = netProto - e.setEndpointState(transport.DatagramEndpointStateBound) + e.state = transport.DatagramEndpointStateBound return nil } @@ -507,7 +521,7 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { defer e.mu.RUnlock() addr := e.info.BindAddr - if e.State() == transport.DatagramEndpointStateConnected { + if e.state == transport.DatagramEndpointStateConnected { addr = e.connectedRoute.LocalAddress() } @@ -522,7 +536,7 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { e.mu.RLock() defer e.mu.RUnlock() - if e.State() != transport.DatagramEndpointStateConnected { + if e.state != transport.DatagramEndpointStateConnected { return tcpip.FullAddress{}, false } @@ -610,7 +624,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - fa, netProto, err := e.checkV4MappedLocked(fa) + fa, netProto, err := e.checkV4MappedRLocked(fa) if err != nil { return err } diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go index 858007156..173197512 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_state.go +++ b/pkg/tcpip/transport/internal/network/endpoint_state.go @@ -35,7 +35,7 @@ func (e *Endpoint) Resume(s *stack.Stack) { } } - switch state := e.State(); state { + switch e.state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: case transport.DatagramEndpointStateBound: if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) { @@ -51,6 +51,6 @@ func (e *Endpoint) Resume(s *stack.Stack) { panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) } default: - panic(fmt.Sprintf("unhandled state = %s", state)) + panic(fmt.Sprintf("unhandled state = %s", e.state)) } } -- cgit v1.2.3