diff options
Diffstat (limited to 'pkg/tcpip/transport/packet/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 40 |
1 files changed, 21 insertions, 19 deletions
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 1f30e5adb..d669fe55e 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -82,11 +82,9 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // +checklocks:mu - netProto tcpip.NetworkProtocolNumber - // +checklocks:mu closed bool // +checklocks:mu - bound bool + boundNetProto tcpip.NetworkProtocolNumber // +checklocks:mu boundNIC tcpip.NICID @@ -98,10 +96,10 @@ type endpoint struct { // NewEndpoint returns a new packet endpoint. func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, + stack: s, + cooked: cooked, + boundNetProto: netProto, + waiterQueue: waiterQueue, } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) @@ -137,7 +135,7 @@ func (ep *endpoint) Close() { return } - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) ep.rcvMu.Lock() defer ep.rcvMu.Unlock() @@ -150,7 +148,6 @@ func (ep *endpoint) Close() { } ep.closed = true - ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } @@ -211,7 +208,7 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc ep.mu.Lock() closed := ep.closed nicID := ep.boundNIC - proto := ep.netProto + proto := ep.boundNetProto ep.mu.Unlock() if closed { return 0, &tcpip.ErrClosedForSend{} @@ -294,24 +291,29 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { defer ep.mu.Unlock() netProto := tcpip.NetworkProtocolNumber(addr.Port) - if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto { - // If the NIC being bound is the same then just return success. + if netProto == 0 { + // Do not allow unbinding the network protocol. + netProto = ep.boundNetProto + } + + if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto { + // Already bound to the requested NIC and network protocol. return nil } - // Unregister endpoint with all the nics. - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) - ep.bound = false + // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new + // binding. + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) + ep.boundNIC = 0 + ep.boundNetProto = 0 // Bind endpoint to receive packets from specific interface. if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil { return err } - ep.bound = true ep.boundNIC = addr.NIC - ep.netProto = netProto - + ep.boundNetProto = netProto return nil } @@ -473,7 +475,7 @@ func (*endpoint) State() uint32 { func (ep *endpoint) Info() tcpip.EndpointInfo { ep.mu.RLock() defer ep.mu.RUnlock() - return &stack.TransportEndpointInfo{NetProto: ep.netProto} + return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto} } // Stats returns a pointer to the endpoint stats. |