diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint_state.go | 6 |
2 files changed, 24 insertions, 22 deletions
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 1f30e5adb..d669fe55e 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -82,11 +82,9 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // +checklocks:mu - netProto tcpip.NetworkProtocolNumber - // +checklocks:mu closed bool // +checklocks:mu - bound bool + boundNetProto tcpip.NetworkProtocolNumber // +checklocks:mu boundNIC tcpip.NICID @@ -98,10 +96,10 @@ type endpoint struct { // NewEndpoint returns a new packet endpoint. func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, + stack: s, + cooked: cooked, + boundNetProto: netProto, + waiterQueue: waiterQueue, } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) @@ -137,7 +135,7 @@ func (ep *endpoint) Close() { return } - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) ep.rcvMu.Lock() defer ep.rcvMu.Unlock() @@ -150,7 +148,6 @@ func (ep *endpoint) Close() { } ep.closed = true - ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } @@ -211,7 +208,7 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc ep.mu.Lock() closed := ep.closed nicID := ep.boundNIC - proto := ep.netProto + proto := ep.boundNetProto ep.mu.Unlock() if closed { return 0, &tcpip.ErrClosedForSend{} @@ -294,24 +291,29 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { defer ep.mu.Unlock() netProto := tcpip.NetworkProtocolNumber(addr.Port) - if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto { - // If the NIC being bound is the same then just return success. + if netProto == 0 { + // Do not allow unbinding the network protocol. + netProto = ep.boundNetProto + } + + if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto { + // Already bound to the requested NIC and network protocol. return nil } - // Unregister endpoint with all the nics. - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) - ep.bound = false + // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new + // binding. + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) + ep.boundNIC = 0 + ep.boundNetProto = 0 // Bind endpoint to receive packets from specific interface. if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil { return err } - ep.bound = true ep.boundNIC = addr.NIC - ep.netProto = netProto - + ep.boundNetProto = netProto return nil } @@ -473,7 +475,7 @@ func (*endpoint) State() uint32 { func (ep *endpoint) Info() tcpip.EndpointInfo { ep.mu.RLock() defer ep.mu.RUnlock() - return &stack.TransportEndpointInfo{NetProto: ep.netProto} + return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto} } // Stats returns a pointer to the endpoint stats. diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index d2768db7b..88cd80ad3 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -57,9 +58,8 @@ func (ep *endpoint) afterLoad() { ep.stack = stack.StackFromEnv ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. - if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { - panic(err) + if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil { + panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err)) } ep.rcvMu.Lock() |