diff options
author | gVisor bot <gvisor-bot@google.com> | 2021-09-17 22:41:44 +0000 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-09-17 22:41:44 +0000 |
commit | 09944a80063ffda7bca135a444c76fe613be67b5 (patch) | |
tree | 7099fcf7a19670b81c992fda2930d53bae5d2c6d | |
parent | 6392b0f3bea052af0de9d95677233dd9e442dbd5 (diff) | |
parent | 7dacdbef528f7b556f23c1b02a360363dc556e31 (diff) |
Merge release-20210906.0-40-g7dacdbef5 (automated)
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint.go | 168 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint_state.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/network_state_autogen.go | 56 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/packet_state_autogen.go | 57 |
6 files changed, 186 insertions, 139 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f79bda922..aa081e90d 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -672,13 +672,10 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) - if a.Protocol != uint16(s.protocol) { - return syserr.ErrInvalidArgument - } - addr = tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + Port: socket.Ntohs(a.Protocol), } } else { if s.minSockAddrLen() > len(sockaddr) { diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index 3cb821475..e3094f59f 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -18,6 +18,7 @@ package network import ( "fmt" + "sync/atomic" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -39,11 +40,7 @@ type Endpoint struct { mu sync.RWMutex `state:"nosave"` // +checklocks:mu - state transport.DatagramEndpointState - // +checklocks:mu wasBound bool - // +checklocks:mu - info stack.TransportEndpointInfo // owner is the owner of transmitted packets. // // +checklocks:mu @@ -72,6 +69,34 @@ type Endpoint struct { ipv4TOS uint8 // +checklocks:mu ipv6TClass uint8 + + // Lock ordering: mu > infoMu. + infoMu sync.RWMutex `state:"nosave"` + // info has a dedicated mutex so that we can avoid lock ordering violations + // when reading the endpoint's info. If we used mu, we need to guarantee + // that any lock taken while mu is held is not held when calling Info() + // which is not true as of writing (we hold mu while registering transport + // endpoints (taking the transport demuxer lock but we also hold the demuxer + // lock when delivering packets/errors to endpoints). + // + // Writes must be performed through setInfo. + // + // +checklocks:infoMu + info stack.TransportEndpointInfo + + // state holds a transport.DatagramBasedEndpointState. + // + // state must be accessed with atomics so that we can avoid lock ordering + // violations when reading the state. If we used mu, we need to guarantee + // that any lock taken while mu is held is not held when calling State() + // which is not true as of writing (we hold mu while registering transport + // endpoints (taking the transport demuxer lock but we also hold the demuxer + // lock when delivering packets/errors to endpoints). + // + // Writes must be performed through setEndpointState. + // + // +checkatomics + state uint32 } // +stateify savable @@ -101,7 +126,6 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr netProto: netProto, transProto: transProto, - state: transport.DatagramEndpointStateInitial, info: stack.TransportEndpointInfo{ NetProto: netProto, TransProto: transProto, @@ -111,6 +135,10 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr multicastTTL: 1, multicastMemberships: make(map[multicastMembership]struct{}), } + + e.mu.Lock() + defer e.mu.Unlock() + e.setEndpointState(transport.DatagramEndpointStateInitial) } // NetProto returns the network protocol the endpoint was initialized with. @@ -118,11 +146,19 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber { return e.netProto } +// setEndpointState sets the state of the endpoint. +// +// e.mu must be held to synchronize changes to state with the rest of the +// endpoint. +// +// +checklocks:e.mu +func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) { + atomic.StoreUint32(&e.state, uint32(state)) +} + // State returns the state of the endpoint. func (e *Endpoint) State() transport.DatagramEndpointState { - e.mu.RLock() - defer e.mu.RUnlock() - return e.state + return transport.DatagramEndpointState(atomic.LoadUint32(&e.state)) } // Close cleans the endpoint's resources and leaves the endpoint in a closed @@ -131,7 +167,7 @@ func (e *Endpoint) Close() { e.mu.Lock() defer e.mu.Unlock() - if e.state == transport.DatagramEndpointStateClosed { + if e.State() == transport.DatagramEndpointStateClosed { return } @@ -145,7 +181,7 @@ func (e *Endpoint) Close() { e.connectedRoute = nil } - e.state = transport.DatagramEndpointStateClosed + e.setEndpointState(transport.DatagramEndpointStateClosed) } // SetOwner sets the owner of transmitted packets. @@ -226,7 +262,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext return WriteContext{}, &tcpip.ErrInvalidOptionValue{} } - if e.state == transport.DatagramEndpointStateClosed { + if e.State() == transport.DatagramEndpointStateClosed { return WriteContext{}, &tcpip.ErrInvalidEndpointState{} } @@ -238,7 +274,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext if opts.To == nil { // If the user doesn't specify a destination, they should have // connected to another address. - if e.state != transport.DatagramEndpointStateConnected { + if e.State() != transport.DatagramEndpointStateConnected { return WriteContext{}, &tcpip.ErrDestinationRequired{} } @@ -250,18 +286,19 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext if nicID == 0 { nicID = tcpip.NICID(e.ops.GetBindToDevice()) } - if e.info.BindNICID != 0 { - if nicID != 0 && nicID != e.info.BindNICID { + info := e.Info() + if info.BindNICID != 0 { + if nicID != 0 && nicID != info.BindNICID { return WriteContext{}, &tcpip.ErrNoRoute{} } - nicID = e.info.BindNICID + nicID = info.BindNICID } if nicID == 0 { - nicID = e.info.RegisterNICID + nicID = info.RegisterNICID } - dst, netProto, err := e.checkV4MappedRLocked(*opts.To) + dst, netProto, err := e.checkV4Mapped(*opts.To) if err != nil { return WriteContext{}, err } @@ -301,20 +338,22 @@ func (e *Endpoint) Disconnect() { e.mu.Lock() defer e.mu.Unlock() - if e.state != transport.DatagramEndpointStateConnected { + if e.State() != transport.DatagramEndpointStateConnected { return } + info := e.Info() // Exclude ephemerally bound endpoints. if e.wasBound { - e.info.ID = stack.TransportEndpointID{ - LocalAddress: e.info.BindAddr, + info.ID = stack.TransportEndpointID{ + LocalAddress: info.BindAddr, } - e.state = transport.DatagramEndpointStateBound + e.setEndpointState(transport.DatagramEndpointStateBound) } else { - e.info.ID = stack.TransportEndpointID{} - e.state = transport.DatagramEndpointStateInitial + info.ID = stack.TransportEndpointID{} + e.setEndpointState(transport.DatagramEndpointStateInitial) } + e.setInfo(info) e.connectedRoute.Release() e.connectedRoute = nil @@ -327,7 +366,7 @@ func (e *Endpoint) Disconnect() { // TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. // +checklocks:e.mu func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { - localAddr := e.info.ID.LocalAddress + localAddr := e.Info().ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" @@ -370,24 +409,25 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. e.mu.Lock() defer e.mu.Unlock() + info := e.Info() nicID := addr.NIC - switch e.state { + switch e.State() { case transport.DatagramEndpointStateInitial: case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: - if e.info.BindNICID == 0 { + if info.BindNICID == 0 { break } - if nicID != 0 && nicID != e.info.BindNICID { + if nicID != 0 && nicID != info.BindNICID { return &tcpip.ErrInvalidEndpointState{} } - nicID = e.info.BindNICID + nicID = info.BindNICID default: return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedRLocked(addr) + addr, netProto, err := e.checkV4Mapped(addr) if err != nil { return err } @@ -398,14 +438,14 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. } id := stack.TransportEndpointID{ - LocalAddress: e.info.ID.LocalAddress, + LocalAddress: info.ID.LocalAddress, RemoteAddress: r.RemoteAddress(), } - if e.state == transport.DatagramEndpointStateInitial { + if e.State() == transport.DatagramEndpointStateInitial { id.LocalAddress = r.LocalAddress() } - if err := f(r.NetProto(), e.info.ID, id); err != nil { + if err := f(r.NetProto(), info.ID, id); err != nil { return err } @@ -414,10 +454,11 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. e.connectedRoute.Release() } e.connectedRoute = r - e.info.ID = id - e.info.RegisterNICID = nicID + info.ID = id + info.RegisterNICID = nicID + e.setInfo(info) e.effectiveNetProto = netProto - e.state = transport.DatagramEndpointStateConnected + e.setEndpointState(transport.DatagramEndpointStateConnected) return nil } @@ -426,7 +467,7 @@ func (e *Endpoint) Shutdown() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - switch state := e.state; state { + switch state := e.State(); state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: @@ -439,11 +480,9 @@ func (e *Endpoint) Shutdown() tcpip.Error { // checkV4MappedRLocked determines the effective network protocol and converts // addr to its canonical form. -// -// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. -// +checklocks:e.mu -func (e *Endpoint) checkV4MappedRLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) +func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { + info := e.Info() + unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -474,11 +513,11 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != transport.DatagramEndpointStateInitial { + if e.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedRLocked(addr) + addr, netProto, err := e.checkV4Mapped(addr) if err != nil { return err } @@ -497,14 +536,16 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto e.wasBound = true - e.info.ID = stack.TransportEndpointID{ + info := e.Info() + info.ID = stack.TransportEndpointID{ LocalAddress: addr.Addr, } - e.info.BindNICID = addr.NIC - e.info.RegisterNICID = nicID - e.info.BindAddr = addr.Addr + info.BindNICID = addr.NIC + info.RegisterNICID = nicID + info.BindAddr = addr.Addr + e.setInfo(info) e.effectiveNetProto = netProto - e.state = transport.DatagramEndpointStateBound + e.setEndpointState(transport.DatagramEndpointStateBound) return nil } @@ -520,13 +561,14 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { e.mu.RLock() defer e.mu.RUnlock() - addr := e.info.BindAddr - if e.state == transport.DatagramEndpointStateConnected { + info := e.Info() + addr := info.BindAddr + if e.State() == transport.DatagramEndpointStateConnected { addr = e.connectedRoute.LocalAddress() } return tcpip.FullAddress{ - NIC: e.info.RegisterNICID, + NIC: info.RegisterNICID, Addr: addr, } } @@ -536,13 +578,13 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != transport.DatagramEndpointStateConnected { + if e.State() != transport.DatagramEndpointStateConnected { return tcpip.FullAddress{}, false } return tcpip.FullAddress{ Addr: e.connectedRoute.RemoteAddress(), - NIC: e.info.RegisterNICID, + NIC: e.Info().RegisterNICID, }, true } @@ -624,7 +666,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - fa, netProto, err := e.checkV4MappedRLocked(fa) + fa, netProto, err := e.checkV4Mapped(fa) if err != nil { return err } @@ -648,7 +690,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } - if e.info.BindNICID != 0 && e.info.BindNICID != nic { + if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic { return &tcpip.ErrInvalidEndpointState{} } @@ -751,7 +793,19 @@ func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { // Info returns a copy of the endpoint info. func (e *Endpoint) Info() stack.TransportEndpointInfo { - e.mu.RLock() - defer e.mu.RUnlock() + e.infoMu.RLock() + defer e.infoMu.RUnlock() return e.info } + +// setInfo sets the endpoint's info. +// +// e.mu must be held to synchronize changes to info with the rest of the +// endpoint. +// +// +checklocks:e.mu +func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) { + e.infoMu.Lock() + defer e.infoMu.Unlock() + e.info = info +} diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go index 173197512..68bd1fbf6 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_state.go +++ b/pkg/tcpip/transport/internal/network/endpoint_state.go @@ -35,22 +35,24 @@ func (e *Endpoint) Resume(s *stack.Stack) { } } - switch e.state { + info := e.Info() + + switch state := e.State(); state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: case transport.DatagramEndpointStateBound: - if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) { - if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 { - panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress)) + if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) { + if e.stack.CheckLocalAddress(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) == 0 { + panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress)) } } case transport.DatagramEndpointStateConnected: var err tcpip.Error multicastLoop := e.ops.GetMulticastLoop() - e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) + e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) if err != nil { - panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) + panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) } default: - panic(fmt.Sprintf("unhandled state = %s", e.state)) + panic(fmt.Sprintf("unhandled state = %s", state)) } } diff --git a/pkg/tcpip/transport/internal/network/network_state_autogen.go b/pkg/tcpip/transport/internal/network/network_state_autogen.go index f72149c1c..1515c8632 100644 --- a/pkg/tcpip/transport/internal/network/network_state_autogen.go +++ b/pkg/tcpip/transport/internal/network/network_state_autogen.go @@ -15,9 +15,7 @@ func (e *Endpoint) StateFields() []string { "ops", "netProto", "transProto", - "state", "wasBound", - "info", "owner", "writeShutdown", "effectiveNetProto", @@ -28,6 +26,8 @@ func (e *Endpoint) StateFields() []string { "multicastNICID", "ipv4TOS", "ipv6TClass", + "info", + "state", } } @@ -39,19 +39,19 @@ func (e *Endpoint) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(0, &e.ops) stateSinkObject.Save(1, &e.netProto) stateSinkObject.Save(2, &e.transProto) - stateSinkObject.Save(3, &e.state) - stateSinkObject.Save(4, &e.wasBound) - stateSinkObject.Save(5, &e.info) - stateSinkObject.Save(6, &e.owner) - stateSinkObject.Save(7, &e.writeShutdown) - stateSinkObject.Save(8, &e.effectiveNetProto) - stateSinkObject.Save(9, &e.multicastMemberships) - stateSinkObject.Save(10, &e.ttl) - stateSinkObject.Save(11, &e.multicastTTL) - stateSinkObject.Save(12, &e.multicastAddr) - stateSinkObject.Save(13, &e.multicastNICID) - stateSinkObject.Save(14, &e.ipv4TOS) - stateSinkObject.Save(15, &e.ipv6TClass) + stateSinkObject.Save(3, &e.wasBound) + stateSinkObject.Save(4, &e.owner) + stateSinkObject.Save(5, &e.writeShutdown) + stateSinkObject.Save(6, &e.effectiveNetProto) + stateSinkObject.Save(7, &e.multicastMemberships) + stateSinkObject.Save(8, &e.ttl) + stateSinkObject.Save(9, &e.multicastTTL) + stateSinkObject.Save(10, &e.multicastAddr) + stateSinkObject.Save(11, &e.multicastNICID) + stateSinkObject.Save(12, &e.ipv4TOS) + stateSinkObject.Save(13, &e.ipv6TClass) + stateSinkObject.Save(14, &e.info) + stateSinkObject.Save(15, &e.state) } func (e *Endpoint) afterLoad() {} @@ -61,19 +61,19 @@ func (e *Endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(0, &e.ops) stateSourceObject.Load(1, &e.netProto) stateSourceObject.Load(2, &e.transProto) - stateSourceObject.Load(3, &e.state) - stateSourceObject.Load(4, &e.wasBound) - stateSourceObject.Load(5, &e.info) - stateSourceObject.Load(6, &e.owner) - stateSourceObject.Load(7, &e.writeShutdown) - stateSourceObject.Load(8, &e.effectiveNetProto) - stateSourceObject.Load(9, &e.multicastMemberships) - stateSourceObject.Load(10, &e.ttl) - stateSourceObject.Load(11, &e.multicastTTL) - stateSourceObject.Load(12, &e.multicastAddr) - stateSourceObject.Load(13, &e.multicastNICID) - stateSourceObject.Load(14, &e.ipv4TOS) - stateSourceObject.Load(15, &e.ipv6TClass) + stateSourceObject.Load(3, &e.wasBound) + stateSourceObject.Load(4, &e.owner) + stateSourceObject.Load(5, &e.writeShutdown) + stateSourceObject.Load(6, &e.effectiveNetProto) + stateSourceObject.Load(7, &e.multicastMemberships) + stateSourceObject.Load(8, &e.ttl) + stateSourceObject.Load(9, &e.multicastTTL) + stateSourceObject.Load(10, &e.multicastAddr) + stateSourceObject.Load(11, &e.multicastNICID) + stateSourceObject.Load(12, &e.ipv4TOS) + stateSourceObject.Load(13, &e.ipv6TClass) + stateSourceObject.Load(14, &e.info) + stateSourceObject.Load(15, &e.state) } func (m *multicastMembership) StateTypeName() string { diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 2c9786175..1f30e5adb 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -59,13 +59,11 @@ type packet struct { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue cooked bool ops tcpip.SocketOptions @@ -84,6 +82,8 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // +checklocks:mu + netProto tcpip.NetworkProtocolNumber + // +checklocks:mu closed bool // +checklocks:mu bound bool @@ -98,10 +98,7 @@ type endpoint struct { // NewEndpoint returns a new packet endpoint. func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - }, + stack: s, cooked: cooked, netProto: netProto, waiterQueue: waiterQueue, @@ -214,13 +211,13 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc ep.mu.Lock() closed := ep.closed nicID := ep.boundNIC + proto := ep.netProto ep.mu.Unlock() if closed { return 0, &tcpip.ErrClosedForSend{} } var remote tcpip.LinkAddress - proto := ep.netProto if to := opts.To; to != nil { remote = tcpip.LinkAddress(to.Addr) @@ -296,7 +293,8 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound && ep.boundNIC == addr.NIC { + netProto := tcpip.NetworkProtocolNumber(addr.Port) + if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto { // If the NIC being bound is the same then just return success. return nil } @@ -306,12 +304,13 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { ep.bound = false // Bind endpoint to receive packets from specific interface. - if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil { return err } ep.bound = true ep.boundNIC = addr.NIC + ep.netProto = netProto return nil } @@ -473,10 +472,8 @@ func (*endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (ep *endpoint) Info() tcpip.EndpointInfo { ep.mu.RLock() - // Make a copy of the endpoint info. - ret := ep.TransportEndpointInfo - ep.mu.RUnlock() - return &ret + defer ep.mu.RUnlock() + return &stack.TransportEndpointInfo{NetProto: ep.netProto} } // Stats returns a pointer to the endpoint stats. diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go index 6017e9e8b..75e6a2070 100644 --- a/pkg/tcpip/transport/packet/packet_state_autogen.go +++ b/pkg/tcpip/transport/packet/packet_state_autogen.go @@ -54,9 +54,7 @@ func (ep *endpoint) StateTypeName() string { func (ep *endpoint) StateFields() []string { return []string{ - "TransportEndpointInfo", "DefaultSocketOptionsHandler", - "netProto", "waiterQueue", "cooked", "ops", @@ -64,6 +62,7 @@ func (ep *endpoint) StateFields() []string { "rcvBufSize", "rcvClosed", "rcvDisabled", + "netProto", "closed", "bound", "boundNIC", @@ -74,38 +73,36 @@ func (ep *endpoint) StateFields() []string { // +checklocksignore func (ep *endpoint) StateSave(stateSinkObject state.Sink) { ep.beforeSave() - stateSinkObject.Save(0, &ep.TransportEndpointInfo) - stateSinkObject.Save(1, &ep.DefaultSocketOptionsHandler) - stateSinkObject.Save(2, &ep.netProto) - stateSinkObject.Save(3, &ep.waiterQueue) - stateSinkObject.Save(4, &ep.cooked) - stateSinkObject.Save(5, &ep.ops) - stateSinkObject.Save(6, &ep.rcvList) - stateSinkObject.Save(7, &ep.rcvBufSize) - stateSinkObject.Save(8, &ep.rcvClosed) - stateSinkObject.Save(9, &ep.rcvDisabled) - stateSinkObject.Save(10, &ep.closed) - stateSinkObject.Save(11, &ep.bound) - stateSinkObject.Save(12, &ep.boundNIC) - stateSinkObject.Save(13, &ep.lastError) + stateSinkObject.Save(0, &ep.DefaultSocketOptionsHandler) + stateSinkObject.Save(1, &ep.waiterQueue) + stateSinkObject.Save(2, &ep.cooked) + stateSinkObject.Save(3, &ep.ops) + stateSinkObject.Save(4, &ep.rcvList) + stateSinkObject.Save(5, &ep.rcvBufSize) + stateSinkObject.Save(6, &ep.rcvClosed) + stateSinkObject.Save(7, &ep.rcvDisabled) + stateSinkObject.Save(8, &ep.netProto) + stateSinkObject.Save(9, &ep.closed) + stateSinkObject.Save(10, &ep.bound) + stateSinkObject.Save(11, &ep.boundNIC) + stateSinkObject.Save(12, &ep.lastError) } // +checklocksignore func (ep *endpoint) StateLoad(stateSourceObject state.Source) { - stateSourceObject.Load(0, &ep.TransportEndpointInfo) - stateSourceObject.Load(1, &ep.DefaultSocketOptionsHandler) - stateSourceObject.Load(2, &ep.netProto) - stateSourceObject.Load(3, &ep.waiterQueue) - stateSourceObject.Load(4, &ep.cooked) - stateSourceObject.Load(5, &ep.ops) - stateSourceObject.Load(6, &ep.rcvList) - stateSourceObject.Load(7, &ep.rcvBufSize) - stateSourceObject.Load(8, &ep.rcvClosed) - stateSourceObject.Load(9, &ep.rcvDisabled) - stateSourceObject.Load(10, &ep.closed) - stateSourceObject.Load(11, &ep.bound) - stateSourceObject.Load(12, &ep.boundNIC) - stateSourceObject.Load(13, &ep.lastError) + stateSourceObject.Load(0, &ep.DefaultSocketOptionsHandler) + stateSourceObject.Load(1, &ep.waiterQueue) + stateSourceObject.Load(2, &ep.cooked) + stateSourceObject.Load(3, &ep.ops) + stateSourceObject.Load(4, &ep.rcvList) + stateSourceObject.Load(5, &ep.rcvBufSize) + stateSourceObject.Load(6, &ep.rcvClosed) + stateSourceObject.Load(7, &ep.rcvDisabled) + stateSourceObject.Load(8, &ep.netProto) + stateSourceObject.Load(9, &ep.closed) + stateSourceObject.Load(10, &ep.bound) + stateSourceObject.Load(11, &ep.boundNIC) + stateSourceObject.Load(12, &ep.lastError) stateSourceObject.AfterLoad(ep.afterLoad) } |