summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-09-15 23:48:05 -0700
committergVisor bot <gvisor-bot@google.com>2021-09-15 23:51:11 -0700
commit477d7e5e10378e2f80f21ac9f536d12c4b94d7ce (patch)
treebcda9c47bdaee2be8f4594f636d2271482d56a64
parenta8ad692fd36cbaf7f5a6b9af39d601053dbee338 (diff)
Annotate checklocks on mutex protected fields
...to catch lock-related bugs in nogo tests. Also update the endpoint's state field to be accessed while the mutex is held instead of requiring atomic operations as nothing needs to call the State method while the mutex is held. Updates #6566. PiperOrigin-RevId: 397010316
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go114
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_state.go4
2 files changed, 66 insertions, 52 deletions
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index 09b629022..3cb821475 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -18,7 +18,6 @@ package network
import (
"fmt"
- "sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,31 +37,41 @@ type Endpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
- // state holds a transport.DatagramBasedEndpointState.
- //
- // state must be read from/written to atomically.
- state uint32
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ state transport.DatagramEndpointState
+ // +checklocks:mu
wasBound bool
- info stack.TransportEndpointInfo
+ // +checklocks:mu
+ info stack.TransportEndpointInfo
// owner is the owner of transmitted packets.
- owner tcpip.PacketOwner
- writeShutdown bool
- effectiveNetProto tcpip.NetworkProtocolNumber
- connectedRoute *stack.Route `state:"manual"`
+ //
+ // +checklocks:mu
+ owner tcpip.PacketOwner
+ // +checklocks:mu
+ writeShutdown bool
+ // +checklocks:mu
+ effectiveNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
+ connectedRoute *stack.Route `state:"manual"`
+ // +checklocks:mu
multicastMemberships map[multicastMembership]struct{}
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
ttl uint8
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastTTL uint8
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastAddr tcpip.Address
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastNICID tcpip.NICID
- ipv4TOS uint8
- ipv6TClass uint8
+ // +checklocks:mu
+ ipv4TOS uint8
+ // +checklocks:mu
+ ipv6TClass uint8
}
// +stateify savable
@@ -73,8 +82,11 @@ type multicastMembership struct {
// Init initializes the endpoint.
func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) {
- if e.multicastMemberships != nil {
- panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships))
+ e.mu.Lock()
+ memberships := e.multicastMemberships
+ e.mu.Unlock()
+ if memberships != nil {
+ panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships))
}
switch netProto {
@@ -89,8 +101,7 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr
netProto: netProto,
transProto: transProto,
- state: uint32(transport.DatagramEndpointStateInitial),
-
+ state: transport.DatagramEndpointStateInitial,
info: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: transProto,
@@ -107,14 +118,11 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
return e.netProto
}
-// setState sets the state of the endpoint.
-func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
- atomic.StoreUint32(&e.state, uint32(state))
-}
-
// State returns the state of the endpoint.
func (e *Endpoint) State() transport.DatagramEndpointState {
- return transport.DatagramEndpointState(atomic.LoadUint32(&e.state))
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.state
}
// Close cleans the endpoint's resources and leaves the endpoint in a closed
@@ -123,7 +131,7 @@ func (e *Endpoint) Close() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.State() == transport.DatagramEndpointStateClosed {
+ if e.state == transport.DatagramEndpointStateClosed {
return
}
@@ -137,7 +145,7 @@ func (e *Endpoint) Close() {
e.connectedRoute = nil
}
- e.setEndpointState(transport.DatagramEndpointStateClosed)
+ e.state = transport.DatagramEndpointStateClosed
}
// SetOwner sets the owner of transmitted packets.
@@ -218,7 +226,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
}
- if e.State() == transport.DatagramEndpointStateClosed {
+ if e.state == transport.DatagramEndpointStateClosed {
return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
}
@@ -230,7 +238,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
if opts.To == nil {
// If the user doesn't specify a destination, they should have
// connected to another address.
- if e.State() != transport.DatagramEndpointStateConnected {
+ if e.state != transport.DatagramEndpointStateConnected {
return WriteContext{}, &tcpip.ErrDestinationRequired{}
}
@@ -253,12 +261,12 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
nicID = e.info.RegisterNICID
}
- dst, netProto, err := e.checkV4MappedLocked(*opts.To)
+ dst, netProto, err := e.checkV4MappedRLocked(*opts.To)
if err != nil {
return WriteContext{}, err
}
- route, _, err = e.connectRoute(nicID, dst, netProto)
+ route, _, err = e.connectRouteRLocked(nicID, dst, netProto)
if err != nil {
return WriteContext{}, err
}
@@ -293,7 +301,7 @@ func (e *Endpoint) Disconnect() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.State() != transport.DatagramEndpointStateConnected {
+ if e.state != transport.DatagramEndpointStateConnected {
return
}
@@ -302,20 +310,23 @@ func (e *Endpoint) Disconnect() {
e.info.ID = stack.TransportEndpointID{
LocalAddress: e.info.BindAddr,
}
- e.setEndpointState(transport.DatagramEndpointStateBound)
+ e.state = transport.DatagramEndpointStateBound
} else {
e.info.ID = stack.TransportEndpointID{}
- e.setEndpointState(transport.DatagramEndpointStateInitial)
+ e.state = transport.DatagramEndpointStateInitial
}
e.connectedRoute.Release()
e.connectedRoute = nil
}
-// connectRoute establishes a route to the specified interface or the
+// connectRouteRLocked establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
+//
+// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
+// +checklocks:e.mu
+func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
localAddr := e.info.ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -360,7 +371,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
defer e.mu.Unlock()
nicID := addr.NIC
- switch e.State() {
+ switch e.state {
case transport.DatagramEndpointStateInitial:
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
if e.info.BindNICID == 0 {
@@ -376,12 +387,12 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
+ addr, netProto, err := e.checkV4MappedRLocked(addr)
if err != nil {
return err
}
- r, nicID, err := e.connectRoute(nicID, addr, netProto)
+ r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto)
if err != nil {
return err
}
@@ -390,7 +401,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
LocalAddress: e.info.ID.LocalAddress,
RemoteAddress: r.RemoteAddress(),
}
- if e.State() == transport.DatagramEndpointStateInitial {
+ if e.state == transport.DatagramEndpointStateInitial {
id.LocalAddress = r.LocalAddress()
}
@@ -406,7 +417,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
e.info.ID = id
e.info.RegisterNICID = nicID
e.effectiveNetProto = netProto
- e.setEndpointState(transport.DatagramEndpointStateConnected)
+ e.state = transport.DatagramEndpointStateConnected
return nil
}
@@ -415,7 +426,7 @@ func (e *Endpoint) Shutdown() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- switch state := e.State(); state {
+ switch state := e.state; state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
@@ -426,9 +437,12 @@ func (e *Endpoint) Shutdown() tcpip.Error {
}
}
-// checkV4MappedLocked determines the effective network protocol and converts
+// checkV4MappedRLocked determines the effective network protocol and converts
// addr to its canonical form.
-func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+//
+// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
+// +checklocks:e.mu
+func (e *Endpoint) checkV4MappedRLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
@@ -460,11 +474,11 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.State() != transport.DatagramEndpointStateInitial {
+ if e.state != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
+ addr, netProto, err := e.checkV4MappedRLocked(addr)
if err != nil {
return err
}
@@ -490,7 +504,7 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
e.info.RegisterNICID = nicID
e.info.BindAddr = addr.Addr
e.effectiveNetProto = netProto
- e.setEndpointState(transport.DatagramEndpointStateBound)
+ e.state = transport.DatagramEndpointStateBound
return nil
}
@@ -507,7 +521,7 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
defer e.mu.RUnlock()
addr := e.info.BindAddr
- if e.State() == transport.DatagramEndpointStateConnected {
+ if e.state == transport.DatagramEndpointStateConnected {
addr = e.connectedRoute.LocalAddress()
}
@@ -522,7 +536,7 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.State() != transport.DatagramEndpointStateConnected {
+ if e.state != transport.DatagramEndpointStateConnected {
return tcpip.FullAddress{}, false
}
@@ -610,7 +624,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedLocked(fa)
+ fa, netProto, err := e.checkV4MappedRLocked(fa)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
index 858007156..173197512 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_state.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -35,7 +35,7 @@ func (e *Endpoint) Resume(s *stack.Stack) {
}
}
- switch state := e.State(); state {
+ switch e.state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
case transport.DatagramEndpointStateBound:
if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) {
@@ -51,6 +51,6 @@ func (e *Endpoint) Resume(s *stack.Stack) {
panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
}
default:
- panic(fmt.Sprintf("unhandled state = %s", state))
+ panic(fmt.Sprintf("unhandled state = %s", e.state))
}
}