From cbf39804647eabafb6138714ed222dbdc4781f2e Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Wed, 12 Sep 2018 09:37:57 -0700 Subject: Prevent UDP sockets from binding to bound ports PiperOrigin-RevId: 212653818 Change-Id: Ib4e1d754d9cdddeaa428a066cb675e6ec44d91ad --- pkg/tcpip/transport/udp/endpoint.go | 35 ++++++--------- pkg/tcpip/transport/udp/endpoint_state.go | 7 ++- pkg/tcpip/transport/udp/udp_test.go | 74 ++++++++++++++++++++++++++++++- 3 files changed, 92 insertions(+), 24 deletions(-) (limited to 'pkg/tcpip/transport') diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 283379a28..d091a6196 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -132,6 +132,7 @@ func (e *endpoint) Close() { switch e.state { case stateBound, stateConnected: e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) } // Close the receive list and drain it. @@ -496,7 +497,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer e.mu.Unlock() nicid := addr.NIC - localPort := uint16(0) + var localPort uint16 switch e.state { case stateInitial: case stateBound, stateConnected: @@ -537,7 +538,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // packets on a different network protocol, so we register both even if // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if e.netProto == header.IPv6ProtocolNumber && !e.v6only { + if netProto == header.IPv6ProtocolNumber && !e.v6only { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv4ProtocolNumber, header.IPv6ProtocolNumber, @@ -611,27 +612,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { - if id.LocalPort != 0 { - // The endpoint already has a local port, just attempt to - // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) - return id, err - } - - // We need to find a port for the endpoint. - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) - switch err { - case nil: - return true, nil - case tcpip.ErrPortInUse: - return false, nil - default: - return false, err + if e.id.LocalPort == 0 { + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + if err != nil { + return id, err } - }) + id.LocalPort = port + } + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) + if err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + } return id, err } @@ -677,6 +669,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error if err := commit(); err != nil { // Unregister, the commit failed. e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) return err } } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 30c16682b..70a37c7f2 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -94,7 +94,12 @@ func (e *endpoint) afterLoad() { } } - e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, e.id) + // Our saved state had a port, but we don't actually have a + // reservation. We need to remove the port from our state, but still + // pass it to the reservation machinery. + id := e.id + e.id.LocalPort = 0 + e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) if err != nil { panic(*err) } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c1c099900..4700193c2 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -112,7 +112,7 @@ func (c *testContext) cleanup() { } } -func (c *testContext) createV6Endpoint(v4only bool) { +func (c *testContext) createV6Endpoint(v6only bool) { var err *tcpip.Error c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) if err != nil { @@ -120,7 +120,7 @@ func (c *testContext) createV6Endpoint(v4only bool) { } var v tcpip.V6OnlyOption - if v4only { + if v6only { v = 1 } if err := c.ep.SetSockOpt(v); err != nil { @@ -296,6 +296,76 @@ func testV4Read(c *testContext) { } } +func TestBindEphemeralPort(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + if err := c.ep.Bind(tcpip.FullAddress{}, nil); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } +} + +func TestBindReservedPort(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + addr, err := c.ep.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress failed: %v", err) + } + + // We can't bind the address reserved by the connected endpoint above. + { + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + if got, want := ep.Bind(addr, nil), tcpip.ErrPortInUse; got != want { + t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) + } + } + + func() { + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + // We can't bind ipv4-any on the port reserved by the connected endpoint + // above, since the endpoint is dual-stack. + if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}, nil), tcpip.ErrPortInUse; got != want { + t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) + } + // We can bind an ipv4 address on this port, though. + if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}, nil); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + }() + + // Once the connected endpoint releases its port reservation, we are able to + // bind ipv4-any once again. + c.ep.Close() + func() { + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}, nil); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + }() +} + func TestV4ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() -- cgit v1.2.3