summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2021-09-16 06:55:28 +0000
committergVisor bot <gvisor-bot@google.com>2021-09-16 06:55:28 +0000
commit812a722c54a8c359f77d1bdfea9f63b810e5a5a8 (patch)
tree1a28a3763630ed49889424a70c89af65c437fdea /pkg
parent9df83e598e95cb0251112d1fca1c4916e74e2f6a (diff)
parent477d7e5e10378e2f80f21ac9f536d12c4b94d7ce (diff)
Merge release-20210906.0-30-g477d7e5e1 (automated)
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go114
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_state.go4
2 files changed, 66 insertions, 52 deletions
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index 09b629022..3cb821475 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -18,7 +18,6 @@ package network
import (
"fmt"
- "sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,31 +37,41 @@ type Endpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
- // state holds a transport.DatagramBasedEndpointState.
- //
- // state must be read from/written to atomically.
- state uint32
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ state transport.DatagramEndpointState
+ // +checklocks:mu
wasBound bool
- info stack.TransportEndpointInfo
+ // +checklocks:mu
+ info stack.TransportEndpointInfo
// owner is the owner of transmitted packets.
- owner tcpip.PacketOwner
- writeShutdown bool
- effectiveNetProto tcpip.NetworkProtocolNumber
- connectedRoute *stack.Route `state:"manual"`
+ //
+ // +checklocks:mu
+ owner tcpip.PacketOwner
+ // +checklocks:mu
+ writeShutdown bool
+ // +checklocks:mu
+ effectiveNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
+ connectedRoute *stack.Route `state:"manual"`
+ // +checklocks:mu
multicastMemberships map[multicastMembership]struct{}
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
ttl uint8
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastTTL uint8
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastAddr tcpip.Address
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastNICID tcpip.NICID
- ipv4TOS uint8
- ipv6TClass uint8
+ // +checklocks:mu
+ ipv4TOS uint8
+ // +checklocks:mu
+ ipv6TClass uint8
}
// +stateify savable
@@ -73,8 +82,11 @@ type multicastMembership struct {
// Init initializes the endpoint.
func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) {
- if e.multicastMemberships != nil {
- panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships))
+ e.mu.Lock()
+ memberships := e.multicastMemberships
+ e.mu.Unlock()
+ if memberships != nil {
+ panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships))
}
switch netProto {
@@ -89,8 +101,7 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr
netProto: netProto,
transProto: transProto,
- state: uint32(transport.DatagramEndpointStateInitial),
-
+ state: transport.DatagramEndpointStateInitial,
info: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: transProto,
@@ -107,14 +118,11 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
return e.netProto
}
-// setState sets the state of the endpoint.
-func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
- atomic.StoreUint32(&e.state, uint32(state))
-}
-
// State returns the state of the endpoint.
func (e *Endpoint) State() transport.DatagramEndpointState {
- return transport.DatagramEndpointState(atomic.LoadUint32(&e.state))
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.state
}
// Close cleans the endpoint's resources and leaves the endpoint in a closed
@@ -123,7 +131,7 @@ func (e *Endpoint) Close() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.State() == transport.DatagramEndpointStateClosed {
+ if e.state == transport.DatagramEndpointStateClosed {
return
}
@@ -137,7 +145,7 @@ func (e *Endpoint) Close() {
e.connectedRoute = nil
}
- e.setEndpointState(transport.DatagramEndpointStateClosed)
+ e.state = transport.DatagramEndpointStateClosed
}
// SetOwner sets the owner of transmitted packets.
@@ -218,7 +226,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
}
- if e.State() == transport.DatagramEndpointStateClosed {
+ if e.state == transport.DatagramEndpointStateClosed {
return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
}
@@ -230,7 +238,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
if opts.To == nil {
// If the user doesn't specify a destination, they should have
// connected to another address.
- if e.State() != transport.DatagramEndpointStateConnected {
+ if e.state != transport.DatagramEndpointStateConnected {
return WriteContext{}, &tcpip.ErrDestinationRequired{}
}
@@ -253,12 +261,12 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
nicID = e.info.RegisterNICID
}
- dst, netProto, err := e.checkV4MappedLocked(*opts.To)
+ dst, netProto, err := e.checkV4MappedRLocked(*opts.To)
if err != nil {
return WriteContext{}, err
}
- route, _, err = e.connectRoute(nicID, dst, netProto)
+ route, _, err = e.connectRouteRLocked(nicID, dst, netProto)
if err != nil {
return WriteContext{}, err
}
@@ -293,7 +301,7 @@ func (e *Endpoint) Disconnect() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.State() != transport.DatagramEndpointStateConnected {
+ if e.state != transport.DatagramEndpointStateConnected {
return
}
@@ -302,20 +310,23 @@ func (e *Endpoint) Disconnect() {
e.info.ID = stack.TransportEndpointID{
LocalAddress: e.info.BindAddr,
}
- e.setEndpointState(transport.DatagramEndpointStateBound)
+ e.state = transport.DatagramEndpointStateBound
} else {
e.info.ID = stack.TransportEndpointID{}
- e.setEndpointState(transport.DatagramEndpointStateInitial)
+ e.state = transport.DatagramEndpointStateInitial
}
e.connectedRoute.Release()
e.connectedRoute = nil
}
-// connectRoute establishes a route to the specified interface or the
+// connectRouteRLocked establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
+//
+// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
+// +checklocks:e.mu
+func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
localAddr := e.info.ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -360,7 +371,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
defer e.mu.Unlock()
nicID := addr.NIC
- switch e.State() {
+ switch e.state {
case transport.DatagramEndpointStateInitial:
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
if e.info.BindNICID == 0 {
@@ -376,12 +387,12 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
+ addr, netProto, err := e.checkV4MappedRLocked(addr)
if err != nil {
return err
}
- r, nicID, err := e.connectRoute(nicID, addr, netProto)
+ r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto)
if err != nil {
return err
}
@@ -390,7 +401,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
LocalAddress: e.info.ID.LocalAddress,
RemoteAddress: r.RemoteAddress(),
}
- if e.State() == transport.DatagramEndpointStateInitial {
+ if e.state == transport.DatagramEndpointStateInitial {
id.LocalAddress = r.LocalAddress()
}
@@ -406,7 +417,7 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
e.info.ID = id
e.info.RegisterNICID = nicID
e.effectiveNetProto = netProto
- e.setEndpointState(transport.DatagramEndpointStateConnected)
+ e.state = transport.DatagramEndpointStateConnected
return nil
}
@@ -415,7 +426,7 @@ func (e *Endpoint) Shutdown() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- switch state := e.State(); state {
+ switch state := e.state; state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
@@ -426,9 +437,12 @@ func (e *Endpoint) Shutdown() tcpip.Error {
}
}
-// checkV4MappedLocked determines the effective network protocol and converts
+// checkV4MappedRLocked determines the effective network protocol and converts
// addr to its canonical form.
-func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+//
+// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
+// +checklocks:e.mu
+func (e *Endpoint) checkV4MappedRLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
@@ -460,11 +474,11 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.State() != transport.DatagramEndpointStateInitial {
+ if e.state != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
+ addr, netProto, err := e.checkV4MappedRLocked(addr)
if err != nil {
return err
}
@@ -490,7 +504,7 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
e.info.RegisterNICID = nicID
e.info.BindAddr = addr.Addr
e.effectiveNetProto = netProto
- e.setEndpointState(transport.DatagramEndpointStateBound)
+ e.state = transport.DatagramEndpointStateBound
return nil
}
@@ -507,7 +521,7 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
defer e.mu.RUnlock()
addr := e.info.BindAddr
- if e.State() == transport.DatagramEndpointStateConnected {
+ if e.state == transport.DatagramEndpointStateConnected {
addr = e.connectedRoute.LocalAddress()
}
@@ -522,7 +536,7 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.State() != transport.DatagramEndpointStateConnected {
+ if e.state != transport.DatagramEndpointStateConnected {
return tcpip.FullAddress{}, false
}
@@ -610,7 +624,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedLocked(fa)
+ fa, netProto, err := e.checkV4MappedRLocked(fa)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
index 858007156..173197512 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_state.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -35,7 +35,7 @@ func (e *Endpoint) Resume(s *stack.Stack) {
}
}
- switch state := e.State(); state {
+ switch e.state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
case transport.DatagramEndpointStateBound:
if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) {
@@ -51,6 +51,6 @@ func (e *Endpoint) Resume(s *stack.Stack) {
panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
}
default:
- panic(fmt.Sprintf("unhandled state = %s", state))
+ panic(fmt.Sprintf("unhandled state = %s", e.state))
}
}