summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go40
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go6
2 files changed, 24 insertions, 22 deletions
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 1f30e5adb..d669fe55e 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -82,11 +82,9 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
// +checklocks:mu
- netProto tcpip.NetworkProtocolNumber
- // +checklocks:mu
closed bool
// +checklocks:mu
- bound bool
+ boundNetProto tcpip.NetworkProtocolNumber
// +checklocks:mu
boundNIC tcpip.NICID
@@ -98,10 +96,10 @@ type endpoint struct {
// NewEndpoint returns a new packet endpoint.
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- cooked: cooked,
- netProto: netProto,
- waiterQueue: waiterQueue,
+ stack: s,
+ cooked: cooked,
+ boundNetProto: netProto,
+ waiterQueue: waiterQueue,
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
@@ -137,7 +135,7 @@ func (ep *endpoint) Close() {
return
}
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
ep.rcvMu.Lock()
defer ep.rcvMu.Unlock()
@@ -150,7 +148,6 @@ func (ep *endpoint) Close() {
}
ep.closed = true
- ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -211,7 +208,7 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc
ep.mu.Lock()
closed := ep.closed
nicID := ep.boundNIC
- proto := ep.netProto
+ proto := ep.boundNetProto
ep.mu.Unlock()
if closed {
return 0, &tcpip.ErrClosedForSend{}
@@ -294,24 +291,29 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
defer ep.mu.Unlock()
netProto := tcpip.NetworkProtocolNumber(addr.Port)
- if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto {
- // If the NIC being bound is the same then just return success.
+ if netProto == 0 {
+ // Do not allow unbinding the network protocol.
+ netProto = ep.boundNetProto
+ }
+
+ if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto {
+ // Already bound to the requested NIC and network protocol.
return nil
}
- // Unregister endpoint with all the nics.
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
- ep.bound = false
+ // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new
+ // binding.
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
+ ep.boundNIC = 0
+ ep.boundNetProto = 0
// Bind endpoint to receive packets from specific interface.
if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
return err
}
- ep.bound = true
ep.boundNIC = addr.NIC
- ep.netProto = netProto
-
+ ep.boundNetProto = netProto
return nil
}
@@ -473,7 +475,7 @@ func (*endpoint) State() uint32 {
func (ep *endpoint) Info() tcpip.EndpointInfo {
ep.mu.RLock()
defer ep.mu.RUnlock()
- return &stack.TransportEndpointInfo{NetProto: ep.netProto}
+ return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto}
}
// Stats returns a pointer to the endpoint stats.
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index d2768db7b..88cd80ad3 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -57,9 +58,8 @@ func (ep *endpoint) afterLoad() {
ep.stack = stack.StackFromEnv
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
- if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
- panic(err)
+ if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil {
+ panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err))
}
ep.rcvMu.Lock()