summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2021-09-17 22:41:44 +0000
committergVisor bot <gvisor-bot@google.com>2021-09-17 22:41:44 +0000
commit09944a80063ffda7bca135a444c76fe613be67b5 (patch)
tree7099fcf7a19670b81c992fda2930d53bae5d2c6d
parent6392b0f3bea052af0de9d95677233dd9e442dbd5 (diff)
parent7dacdbef528f7b556f23c1b02a360363dc556e31 (diff)
Merge release-20210906.0-40-g7dacdbef5 (automated)
-rw-r--r--pkg/sentry/socket/netstack/netstack.go5
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go168
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_state.go16
-rw-r--r--pkg/tcpip/transport/internal/network/network_state_autogen.go56
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go23
-rw-r--r--pkg/tcpip/transport/packet/packet_state_autogen.go57
6 files changed, 186 insertions, 139 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index f79bda922..aa081e90d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -672,13 +672,10 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
a.UnmarshalBytes(sockaddr[:sockAddrLinkSize])
- if a.Protocol != uint16(s.protocol) {
- return syserr.ErrInvalidArgument
- }
-
addr = tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ Port: socket.Ntohs(a.Protocol),
}
} else {
if s.minSockAddrLen() > len(sockaddr) {
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index 3cb821475..e3094f59f 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -18,6 +18,7 @@ package network
import (
"fmt"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -39,11 +40,7 @@ type Endpoint struct {
mu sync.RWMutex `state:"nosave"`
// +checklocks:mu
- state transport.DatagramEndpointState
- // +checklocks:mu
wasBound bool
- // +checklocks:mu
- info stack.TransportEndpointInfo
// owner is the owner of transmitted packets.
//
// +checklocks:mu
@@ -72,6 +69,34 @@ type Endpoint struct {
ipv4TOS uint8
// +checklocks:mu
ipv6TClass uint8
+
+ // Lock ordering: mu > infoMu.
+ infoMu sync.RWMutex `state:"nosave"`
+ // info has a dedicated mutex so that we can avoid lock ordering violations
+ // when reading the endpoint's info. If we used mu, we need to guarantee
+ // that any lock taken while mu is held is not held when calling Info()
+ // which is not true as of writing (we hold mu while registering transport
+ // endpoints (taking the transport demuxer lock but we also hold the demuxer
+ // lock when delivering packets/errors to endpoints).
+ //
+ // Writes must be performed through setInfo.
+ //
+ // +checklocks:infoMu
+ info stack.TransportEndpointInfo
+
+ // state holds a transport.DatagramBasedEndpointState.
+ //
+ // state must be accessed with atomics so that we can avoid lock ordering
+ // violations when reading the state. If we used mu, we need to guarantee
+ // that any lock taken while mu is held is not held when calling State()
+ // which is not true as of writing (we hold mu while registering transport
+ // endpoints (taking the transport demuxer lock but we also hold the demuxer
+ // lock when delivering packets/errors to endpoints).
+ //
+ // Writes must be performed through setEndpointState.
+ //
+ // +checkatomics
+ state uint32
}
// +stateify savable
@@ -101,7 +126,6 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr
netProto: netProto,
transProto: transProto,
- state: transport.DatagramEndpointStateInitial,
info: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: transProto,
@@ -111,6 +135,10 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr
multicastTTL: 1,
multicastMemberships: make(map[multicastMembership]struct{}),
}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
}
// NetProto returns the network protocol the endpoint was initialized with.
@@ -118,11 +146,19 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
return e.netProto
}
+// setEndpointState sets the state of the endpoint.
+//
+// e.mu must be held to synchronize changes to state with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
+func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
+ atomic.StoreUint32(&e.state, uint32(state))
+}
+
// State returns the state of the endpoint.
func (e *Endpoint) State() transport.DatagramEndpointState {
- e.mu.RLock()
- defer e.mu.RUnlock()
- return e.state
+ return transport.DatagramEndpointState(atomic.LoadUint32(&e.state))
}
// Close cleans the endpoint's resources and leaves the endpoint in a closed
@@ -131,7 +167,7 @@ func (e *Endpoint) Close() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state == transport.DatagramEndpointStateClosed {
+ if e.State() == transport.DatagramEndpointStateClosed {
return
}
@@ -145,7 +181,7 @@ func (e *Endpoint) Close() {
e.connectedRoute = nil
}
- e.state = transport.DatagramEndpointStateClosed
+ e.setEndpointState(transport.DatagramEndpointStateClosed)
}
// SetOwner sets the owner of transmitted packets.
@@ -226,7 +262,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
}
- if e.state == transport.DatagramEndpointStateClosed {
+ if e.State() == transport.DatagramEndpointStateClosed {
return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
}
@@ -238,7 +274,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
if opts.To == nil {
// If the user doesn't specify a destination, they should have
// connected to another address.
- if e.state != transport.DatagramEndpointStateConnected {
+ if e.State() != transport.DatagramEndpointStateConnected {
return WriteContext{}, &tcpip.ErrDestinationRequired{}
}
@@ -250,18 +286,19 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
if nicID == 0 {
nicID = tcpip.NICID(e.ops.GetBindToDevice())
}
- if e.info.BindNICID != 0 {
- if nicID != 0 && nicID != e.info.BindNICID {
+ info := e.Info()
+ if info.BindNICID != 0 {
+ if nicID != 0 && nicID != info.BindNICID {
return WriteContext{}, &tcpip.ErrNoRoute{}
}
- nicID = e.info.BindNICID
+ nicID = info.BindNICID
}
if nicID == 0 {
- nicID = e.info.RegisterNICID
+ nicID = info.RegisterNICID
}
- dst, netProto, err := e.checkV4MappedRLocked(*opts.To)
+ dst, netProto, err := e.checkV4Mapped(*opts.To)
if err != nil {
return WriteContext{}, err
}
@@ -301,20 +338,22 @@ func (e *Endpoint) Disconnect() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != transport.DatagramEndpointStateConnected {
+ if e.State() != transport.DatagramEndpointStateConnected {
return
}
+ info := e.Info()
// Exclude ephemerally bound endpoints.
if e.wasBound {
- e.info.ID = stack.TransportEndpointID{
- LocalAddress: e.info.BindAddr,
+ info.ID = stack.TransportEndpointID{
+ LocalAddress: info.BindAddr,
}
- e.state = transport.DatagramEndpointStateBound
+ e.setEndpointState(transport.DatagramEndpointStateBound)
} else {
- e.info.ID = stack.TransportEndpointID{}
- e.state = transport.DatagramEndpointStateInitial
+ info.ID = stack.TransportEndpointID{}
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
}
+ e.setInfo(info)
e.connectedRoute.Release()
e.connectedRoute = nil
@@ -327,7 +366,7 @@ func (e *Endpoint) Disconnect() {
// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
// +checklocks:e.mu
func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
- localAddr := e.info.ID.LocalAddress
+ localAddr := e.Info().ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
localAddr = ""
@@ -370,24 +409,25 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
e.mu.Lock()
defer e.mu.Unlock()
+ info := e.Info()
nicID := addr.NIC
- switch e.state {
+ switch e.State() {
case transport.DatagramEndpointStateInitial:
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
- if e.info.BindNICID == 0 {
+ if info.BindNICID == 0 {
break
}
- if nicID != 0 && nicID != e.info.BindNICID {
+ if nicID != 0 && nicID != info.BindNICID {
return &tcpip.ErrInvalidEndpointState{}
}
- nicID = e.info.BindNICID
+ nicID = info.BindNICID
default:
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedRLocked(addr)
+ addr, netProto, err := e.checkV4Mapped(addr)
if err != nil {
return err
}
@@ -398,14 +438,14 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
}
id := stack.TransportEndpointID{
- LocalAddress: e.info.ID.LocalAddress,
+ LocalAddress: info.ID.LocalAddress,
RemoteAddress: r.RemoteAddress(),
}
- if e.state == transport.DatagramEndpointStateInitial {
+ if e.State() == transport.DatagramEndpointStateInitial {
id.LocalAddress = r.LocalAddress()
}
- if err := f(r.NetProto(), e.info.ID, id); err != nil {
+ if err := f(r.NetProto(), info.ID, id); err != nil {
return err
}
@@ -414,10 +454,11 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
e.connectedRoute.Release()
}
e.connectedRoute = r
- e.info.ID = id
- e.info.RegisterNICID = nicID
+ info.ID = id
+ info.RegisterNICID = nicID
+ e.setInfo(info)
e.effectiveNetProto = netProto
- e.state = transport.DatagramEndpointStateConnected
+ e.setEndpointState(transport.DatagramEndpointStateConnected)
return nil
}
@@ -426,7 +467,7 @@ func (e *Endpoint) Shutdown() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- switch state := e.state; state {
+ switch state := e.State(); state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
@@ -439,11 +480,9 @@ func (e *Endpoint) Shutdown() tcpip.Error {
// checkV4MappedRLocked determines the effective network protocol and converts
// addr to its canonical form.
-//
-// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
-// +checklocks:e.mu
-func (e *Endpoint) checkV4MappedRLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
+func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+ info := e.Info()
+ unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -474,11 +513,11 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != transport.DatagramEndpointStateInitial {
+ if e.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedRLocked(addr)
+ addr, netProto, err := e.checkV4Mapped(addr)
if err != nil {
return err
}
@@ -497,14 +536,16 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
e.wasBound = true
- e.info.ID = stack.TransportEndpointID{
+ info := e.Info()
+ info.ID = stack.TransportEndpointID{
LocalAddress: addr.Addr,
}
- e.info.BindNICID = addr.NIC
- e.info.RegisterNICID = nicID
- e.info.BindAddr = addr.Addr
+ info.BindNICID = addr.NIC
+ info.RegisterNICID = nicID
+ info.BindAddr = addr.Addr
+ e.setInfo(info)
e.effectiveNetProto = netProto
- e.state = transport.DatagramEndpointStateBound
+ e.setEndpointState(transport.DatagramEndpointStateBound)
return nil
}
@@ -520,13 +561,14 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
e.mu.RLock()
defer e.mu.RUnlock()
- addr := e.info.BindAddr
- if e.state == transport.DatagramEndpointStateConnected {
+ info := e.Info()
+ addr := info.BindAddr
+ if e.State() == transport.DatagramEndpointStateConnected {
addr = e.connectedRoute.LocalAddress()
}
return tcpip.FullAddress{
- NIC: e.info.RegisterNICID,
+ NIC: info.RegisterNICID,
Addr: addr,
}
}
@@ -536,13 +578,13 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != transport.DatagramEndpointStateConnected {
+ if e.State() != transport.DatagramEndpointStateConnected {
return tcpip.FullAddress{}, false
}
return tcpip.FullAddress{
Addr: e.connectedRoute.RemoteAddress(),
- NIC: e.info.RegisterNICID,
+ NIC: e.Info().RegisterNICID,
}, true
}
@@ -624,7 +666,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedRLocked(fa)
+ fa, netProto, err := e.checkV4Mapped(fa)
if err != nil {
return err
}
@@ -648,7 +690,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
}
- if e.info.BindNICID != 0 && e.info.BindNICID != nic {
+ if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic {
return &tcpip.ErrInvalidEndpointState{}
}
@@ -751,7 +793,19 @@ func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
// Info returns a copy of the endpoint info.
func (e *Endpoint) Info() stack.TransportEndpointInfo {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.infoMu.RLock()
+ defer e.infoMu.RUnlock()
return e.info
}
+
+// setInfo sets the endpoint's info.
+//
+// e.mu must be held to synchronize changes to info with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
+func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) {
+ e.infoMu.Lock()
+ defer e.infoMu.Unlock()
+ e.info = info
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
index 173197512..68bd1fbf6 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_state.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -35,22 +35,24 @@ func (e *Endpoint) Resume(s *stack.Stack) {
}
}
- switch e.state {
+ info := e.Info()
+
+ switch state := e.State(); state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
case transport.DatagramEndpointStateBound:
- if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) {
- if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 {
- panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress))
+ if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) {
+ if e.stack.CheckLocalAddress(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) == 0 {
+ panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress))
}
}
case transport.DatagramEndpointStateConnected:
var err tcpip.Error
multicastLoop := e.ops.GetMulticastLoop()
- e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
+ e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
if err != nil {
- panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
+ panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
}
default:
- panic(fmt.Sprintf("unhandled state = %s", e.state))
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
diff --git a/pkg/tcpip/transport/internal/network/network_state_autogen.go b/pkg/tcpip/transport/internal/network/network_state_autogen.go
index f72149c1c..1515c8632 100644
--- a/pkg/tcpip/transport/internal/network/network_state_autogen.go
+++ b/pkg/tcpip/transport/internal/network/network_state_autogen.go
@@ -15,9 +15,7 @@ func (e *Endpoint) StateFields() []string {
"ops",
"netProto",
"transProto",
- "state",
"wasBound",
- "info",
"owner",
"writeShutdown",
"effectiveNetProto",
@@ -28,6 +26,8 @@ func (e *Endpoint) StateFields() []string {
"multicastNICID",
"ipv4TOS",
"ipv6TClass",
+ "info",
+ "state",
}
}
@@ -39,19 +39,19 @@ func (e *Endpoint) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(0, &e.ops)
stateSinkObject.Save(1, &e.netProto)
stateSinkObject.Save(2, &e.transProto)
- stateSinkObject.Save(3, &e.state)
- stateSinkObject.Save(4, &e.wasBound)
- stateSinkObject.Save(5, &e.info)
- stateSinkObject.Save(6, &e.owner)
- stateSinkObject.Save(7, &e.writeShutdown)
- stateSinkObject.Save(8, &e.effectiveNetProto)
- stateSinkObject.Save(9, &e.multicastMemberships)
- stateSinkObject.Save(10, &e.ttl)
- stateSinkObject.Save(11, &e.multicastTTL)
- stateSinkObject.Save(12, &e.multicastAddr)
- stateSinkObject.Save(13, &e.multicastNICID)
- stateSinkObject.Save(14, &e.ipv4TOS)
- stateSinkObject.Save(15, &e.ipv6TClass)
+ stateSinkObject.Save(3, &e.wasBound)
+ stateSinkObject.Save(4, &e.owner)
+ stateSinkObject.Save(5, &e.writeShutdown)
+ stateSinkObject.Save(6, &e.effectiveNetProto)
+ stateSinkObject.Save(7, &e.multicastMemberships)
+ stateSinkObject.Save(8, &e.ttl)
+ stateSinkObject.Save(9, &e.multicastTTL)
+ stateSinkObject.Save(10, &e.multicastAddr)
+ stateSinkObject.Save(11, &e.multicastNICID)
+ stateSinkObject.Save(12, &e.ipv4TOS)
+ stateSinkObject.Save(13, &e.ipv6TClass)
+ stateSinkObject.Save(14, &e.info)
+ stateSinkObject.Save(15, &e.state)
}
func (e *Endpoint) afterLoad() {}
@@ -61,19 +61,19 @@ func (e *Endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(0, &e.ops)
stateSourceObject.Load(1, &e.netProto)
stateSourceObject.Load(2, &e.transProto)
- stateSourceObject.Load(3, &e.state)
- stateSourceObject.Load(4, &e.wasBound)
- stateSourceObject.Load(5, &e.info)
- stateSourceObject.Load(6, &e.owner)
- stateSourceObject.Load(7, &e.writeShutdown)
- stateSourceObject.Load(8, &e.effectiveNetProto)
- stateSourceObject.Load(9, &e.multicastMemberships)
- stateSourceObject.Load(10, &e.ttl)
- stateSourceObject.Load(11, &e.multicastTTL)
- stateSourceObject.Load(12, &e.multicastAddr)
- stateSourceObject.Load(13, &e.multicastNICID)
- stateSourceObject.Load(14, &e.ipv4TOS)
- stateSourceObject.Load(15, &e.ipv6TClass)
+ stateSourceObject.Load(3, &e.wasBound)
+ stateSourceObject.Load(4, &e.owner)
+ stateSourceObject.Load(5, &e.writeShutdown)
+ stateSourceObject.Load(6, &e.effectiveNetProto)
+ stateSourceObject.Load(7, &e.multicastMemberships)
+ stateSourceObject.Load(8, &e.ttl)
+ stateSourceObject.Load(9, &e.multicastTTL)
+ stateSourceObject.Load(10, &e.multicastAddr)
+ stateSourceObject.Load(11, &e.multicastNICID)
+ stateSourceObject.Load(12, &e.ipv4TOS)
+ stateSourceObject.Load(13, &e.ipv6TClass)
+ stateSourceObject.Load(14, &e.info)
+ stateSourceObject.Load(15, &e.state)
}
func (m *multicastMembership) StateTypeName() string {
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 2c9786175..1f30e5adb 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -59,13 +59,11 @@ type packet struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
cooked bool
ops tcpip.SocketOptions
@@ -84,6 +82,8 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
// +checklocks:mu
+ netProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
closed bool
// +checklocks:mu
bound bool
@@ -98,10 +98,7 @@ type endpoint struct {
// NewEndpoint returns a new packet endpoint.
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- },
+ stack: s,
cooked: cooked,
netProto: netProto,
waiterQueue: waiterQueue,
@@ -214,13 +211,13 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc
ep.mu.Lock()
closed := ep.closed
nicID := ep.boundNIC
+ proto := ep.netProto
ep.mu.Unlock()
if closed {
return 0, &tcpip.ErrClosedForSend{}
}
var remote tcpip.LinkAddress
- proto := ep.netProto
if to := opts.To; to != nil {
remote = tcpip.LinkAddress(to.Addr)
@@ -296,7 +293,8 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound && ep.boundNIC == addr.NIC {
+ netProto := tcpip.NetworkProtocolNumber(addr.Port)
+ if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto {
// If the NIC being bound is the same then just return success.
return nil
}
@@ -306,12 +304,13 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.bound = false
// Bind endpoint to receive packets from specific interface.
- if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
return err
}
ep.bound = true
ep.boundNIC = addr.NIC
+ ep.netProto = netProto
return nil
}
@@ -473,10 +472,8 @@ func (*endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (ep *endpoint) Info() tcpip.EndpointInfo {
ep.mu.RLock()
- // Make a copy of the endpoint info.
- ret := ep.TransportEndpointInfo
- ep.mu.RUnlock()
- return &ret
+ defer ep.mu.RUnlock()
+ return &stack.TransportEndpointInfo{NetProto: ep.netProto}
}
// Stats returns a pointer to the endpoint stats.
diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go
index 6017e9e8b..75e6a2070 100644
--- a/pkg/tcpip/transport/packet/packet_state_autogen.go
+++ b/pkg/tcpip/transport/packet/packet_state_autogen.go
@@ -54,9 +54,7 @@ func (ep *endpoint) StateTypeName() string {
func (ep *endpoint) StateFields() []string {
return []string{
- "TransportEndpointInfo",
"DefaultSocketOptionsHandler",
- "netProto",
"waiterQueue",
"cooked",
"ops",
@@ -64,6 +62,7 @@ func (ep *endpoint) StateFields() []string {
"rcvBufSize",
"rcvClosed",
"rcvDisabled",
+ "netProto",
"closed",
"bound",
"boundNIC",
@@ -74,38 +73,36 @@ func (ep *endpoint) StateFields() []string {
// +checklocksignore
func (ep *endpoint) StateSave(stateSinkObject state.Sink) {
ep.beforeSave()
- stateSinkObject.Save(0, &ep.TransportEndpointInfo)
- stateSinkObject.Save(1, &ep.DefaultSocketOptionsHandler)
- stateSinkObject.Save(2, &ep.netProto)
- stateSinkObject.Save(3, &ep.waiterQueue)
- stateSinkObject.Save(4, &ep.cooked)
- stateSinkObject.Save(5, &ep.ops)
- stateSinkObject.Save(6, &ep.rcvList)
- stateSinkObject.Save(7, &ep.rcvBufSize)
- stateSinkObject.Save(8, &ep.rcvClosed)
- stateSinkObject.Save(9, &ep.rcvDisabled)
- stateSinkObject.Save(10, &ep.closed)
- stateSinkObject.Save(11, &ep.bound)
- stateSinkObject.Save(12, &ep.boundNIC)
- stateSinkObject.Save(13, &ep.lastError)
+ stateSinkObject.Save(0, &ep.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(1, &ep.waiterQueue)
+ stateSinkObject.Save(2, &ep.cooked)
+ stateSinkObject.Save(3, &ep.ops)
+ stateSinkObject.Save(4, &ep.rcvList)
+ stateSinkObject.Save(5, &ep.rcvBufSize)
+ stateSinkObject.Save(6, &ep.rcvClosed)
+ stateSinkObject.Save(7, &ep.rcvDisabled)
+ stateSinkObject.Save(8, &ep.netProto)
+ stateSinkObject.Save(9, &ep.closed)
+ stateSinkObject.Save(10, &ep.bound)
+ stateSinkObject.Save(11, &ep.boundNIC)
+ stateSinkObject.Save(12, &ep.lastError)
}
// +checklocksignore
func (ep *endpoint) StateLoad(stateSourceObject state.Source) {
- stateSourceObject.Load(0, &ep.TransportEndpointInfo)
- stateSourceObject.Load(1, &ep.DefaultSocketOptionsHandler)
- stateSourceObject.Load(2, &ep.netProto)
- stateSourceObject.Load(3, &ep.waiterQueue)
- stateSourceObject.Load(4, &ep.cooked)
- stateSourceObject.Load(5, &ep.ops)
- stateSourceObject.Load(6, &ep.rcvList)
- stateSourceObject.Load(7, &ep.rcvBufSize)
- stateSourceObject.Load(8, &ep.rcvClosed)
- stateSourceObject.Load(9, &ep.rcvDisabled)
- stateSourceObject.Load(10, &ep.closed)
- stateSourceObject.Load(11, &ep.bound)
- stateSourceObject.Load(12, &ep.boundNIC)
- stateSourceObject.Load(13, &ep.lastError)
+ stateSourceObject.Load(0, &ep.DefaultSocketOptionsHandler)
+ stateSourceObject.Load(1, &ep.waiterQueue)
+ stateSourceObject.Load(2, &ep.cooked)
+ stateSourceObject.Load(3, &ep.ops)
+ stateSourceObject.Load(4, &ep.rcvList)
+ stateSourceObject.Load(5, &ep.rcvBufSize)
+ stateSourceObject.Load(6, &ep.rcvClosed)
+ stateSourceObject.Load(7, &ep.rcvDisabled)
+ stateSourceObject.Load(8, &ep.netProto)
+ stateSourceObject.Load(9, &ep.closed)
+ stateSourceObject.Load(10, &ep.bound)
+ stateSourceObject.Load(11, &ep.boundNIC)
+ stateSourceObject.Load(12, &ep.lastError)
stateSourceObject.AfterLoad(ep.afterLoad)
}