From 5951ec5bce17e7696d2fd53ce384839555dd3c79 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 20 Sep 2021 12:14:29 -0700 Subject: Do not allow unbinding network protocol Once a packet socket is bound to a network protocol, it cannot be unbound from that protocol; the network protocol binding may only be updated to a different network protocol. To comply with Linux. PiperOrigin-RevId: 397810878 --- pkg/tcpip/transport/packet/endpoint.go | 40 +++++++++++++++------------- pkg/tcpip/transport/packet/endpoint_state.go | 6 ++--- test/syscalls/linux/packet_socket.cc | 5 ++++ 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 1f30e5adb..d669fe55e 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -82,11 +82,9 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // +checklocks:mu - netProto tcpip.NetworkProtocolNumber - // +checklocks:mu closed bool // +checklocks:mu - bound bool + boundNetProto tcpip.NetworkProtocolNumber // +checklocks:mu boundNIC tcpip.NICID @@ -98,10 +96,10 @@ type endpoint struct { // NewEndpoint returns a new packet endpoint. func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, + stack: s, + cooked: cooked, + boundNetProto: netProto, + waiterQueue: waiterQueue, } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) @@ -137,7 +135,7 @@ func (ep *endpoint) Close() { return } - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) ep.rcvMu.Lock() defer ep.rcvMu.Unlock() @@ -150,7 +148,6 @@ func (ep *endpoint) Close() { } ep.closed = true - ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } @@ -211,7 +208,7 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc ep.mu.Lock() closed := ep.closed nicID := ep.boundNIC - proto := ep.netProto + proto := ep.boundNetProto ep.mu.Unlock() if closed { return 0, &tcpip.ErrClosedForSend{} @@ -294,24 +291,29 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { defer ep.mu.Unlock() netProto := tcpip.NetworkProtocolNumber(addr.Port) - if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto { - // If the NIC being bound is the same then just return success. + if netProto == 0 { + // Do not allow unbinding the network protocol. + netProto = ep.boundNetProto + } + + if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto { + // Already bound to the requested NIC and network protocol. return nil } - // Unregister endpoint with all the nics. - ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) - ep.bound = false + // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new + // binding. + ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep) + ep.boundNIC = 0 + ep.boundNetProto = 0 // Bind endpoint to receive packets from specific interface. if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil { return err } - ep.bound = true ep.boundNIC = addr.NIC - ep.netProto = netProto - + ep.boundNetProto = netProto return nil } @@ -473,7 +475,7 @@ func (*endpoint) State() uint32 { func (ep *endpoint) Info() tcpip.EndpointInfo { ep.mu.RLock() defer ep.mu.RUnlock() - return &stack.TransportEndpointInfo{NetProto: ep.netProto} + return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto} } // Stats returns a pointer to the endpoint stats. diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index d2768db7b..88cd80ad3 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -57,9 +58,8 @@ func (ep *endpoint) afterLoad() { ep.stack = stack.StackFromEnv ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. - if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { - panic(err) + if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil { + panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err)) } ep.rcvMu.Lock() diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 43828a52e..81e607a3f 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -214,6 +214,11 @@ TEST_P(PacketSocketTest, RebindProtocol) { ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP)); ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); ASSERT_NO_FATAL_FAILURE(test_recv(counter)); + + // A zero valued protocol number should not change the bound network protocol. + ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(0)); + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + ASSERT_NO_FATAL_FAILURE(test_recv(counter)); } INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest, -- cgit v1.2.3