From d6409b6564d6f908a217709010df2276497b264b Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Mon, 17 Sep 2018 20:42:48 -0700 Subject: Prevent TCP connect from picking bound ports PiperOrigin-RevId: 213387851 Change-Id: Icc6850761bc11afd0525f34863acd77584155140 --- pkg/tcpip/ports/ports.go | 46 ++++++--- pkg/tcpip/ports/ports_test.go | 10 +- pkg/tcpip/transport/tcp/BUILD | 2 + pkg/tcpip/transport/tcp/endpoint.go | 20 ++-- pkg/tcpip/transport/tcp/tcp_test.go | 194 ++++++++++++++++++++++++++++++++++-- 5 files changed, 233 insertions(+), 39 deletions(-) (limited to 'pkg') diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index db7371efb..4e24efddb 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -24,8 +24,8 @@ import ( ) const ( - // firstEphemeral is the first ephemeral port. - firstEphemeral uint16 = 16000 + // FirstEphemeral is the first ephemeral port. + FirstEphemeral = 16000 anyIPAddress tcpip.Address = "" ) @@ -73,11 +73,11 @@ func NewPortManager() *PortManager { // is suitable for its needs, and stopping when a port is found or an error // occurs. func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { - count := uint16(math.MaxUint16 - firstEphemeral + 1) + count := uint16(math.MaxUint16 - FirstEphemeral + 1) offset := uint16(rand.Int31n(int32(count))) for i := uint16(0); i < count; i++ { - port = firstEphemeral + (offset+i)%count + port = FirstEphemeral + (offset+i)%count ok, err := testPort(port) if err != nil { return 0, err @@ -91,6 +91,25 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er return 0, tcpip.ErrNoPortAvailable } +// IsPortAvailable tests if the given port is available on all given protocols. +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.isPortAvailableLocked(networks, transport, addr, port) +} + +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool { + for _, network := range networks { + desc := portDescriptor{network, transport, port} + if addrs, ok := s.allocatedPorts[desc]; ok { + if !addrs.isAvailable(addr) { + return false + } + } + } + return true +} + // ReservePort marks a port/IP combination as reserved so that it cannot be // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the @@ -116,14 +135,8 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // reserveSpecificPort tries to reserve the given port on all given protocols. func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool { - // Check that the port is available on all network protocols. - for _, network := range networks { - desc := portDescriptor{network, transport, port} - if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr) { - return false - } - } + if !s.isPortAvailableLocked(networks, transport, addr, port) { + return false } // Reserve port on all network protocols. @@ -148,10 +161,11 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp for _, network := range networks { desc := portDescriptor{network, transport, port} - m := s.allocatedPorts[desc] - delete(m, addr) - if len(m) == 0 { - delete(s.allocatedPorts, desc) + if m, ok := s.allocatedPorts[desc]; ok { + delete(m, addr) + if len(m) == 0 { + delete(s.allocatedPorts, desc) + } } } } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 825d5d314..4ab6a1fa2 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -78,8 +78,8 @@ func TestPortReservation(t *testing.T) { if err != test.want { t.Fatalf("ReservePort(.., .., %s, %d) = %v, want %v", test.ip, test.port, err, test.want) } - if test.port == 0 && (gotPort == 0 || gotPort < firstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, firstEphemeral) + if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { + t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) } } @@ -118,17 +118,17 @@ func TestPickEphemeralPort(t *testing.T) { { name: "only-port-16042-available", f: func(port uint16) (bool, *tcpip.Error) { - if port == firstEphemeral+42 { + if port == FirstEphemeral+42 { return true, nil } return false, nil }, - wantPort: firstEphemeral + 42, + wantPort: FirstEphemeral + 42, }, { name: "only-port-under-16000-available", f: func(port uint16) (bool, *tcpip.Error) { - if port < firstEphemeral { + if port < FirstEphemeral { return true, nil } return false, nil diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index c7943f08e..5a77ee232 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -71,6 +71,8 @@ go_test( "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp/testing/context", diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4085585b0..a048cadf8 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1000,23 +1000,26 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). sameAddr := e.id.LocalAddress == e.id.RemoteAddress - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.id.RemotePort { return false, nil } + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) { + return false, nil + } - e.id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e) - switch err { + id := e.id + id.LocalPort = p + switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) { case nil: + e.id = id return true, nil case tcpip.ErrPortInUse: return false, nil default: return false, err } - }) - if err != nil { + }); err != nil { return err } } @@ -1217,7 +1220,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } // Bind binds the endpoint to a specific local port and optionally address. -func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (retErr *tcpip.Error) { +func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() @@ -1245,7 +1248,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (ret } } - // Reserve the port. port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port) if err != nil { return err @@ -1257,7 +1259,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (ret // Any failures beyond this point must remove the port registration. defer func() { - if retErr != nil { + if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) e.isPortReserved = false e.effectiveNetProtos = nil diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 871177842..ac21e565b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -28,6 +28,8 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback" "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6" + "gvisor.googlesource.com/gvisor/pkg/tcpip/ports" "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" @@ -3013,12 +3015,11 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, tcp.DefaultBufferSize*30) } -func TestSelfConnect(t *testing.T) { - // This test ensures that intentional self-connects work. In particular, - // it checks that if an endpoint binds to say 127.0.0.1:1000 then - // connects to 127.0.0.1:1000, then it will be connected to itself, and - // is able to send and receive data through the same endpoint. - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) +func makeStack() (*stack.Stack, *tcpip.Error) { + s := stack.New([]string{ + ipv4.ProtocolName, + ipv6.ProtocolName, + }, []string{tcp.ProtocolName}, stack.Options{}) id := loopback.New() if testing.Verbose() { @@ -3026,11 +3027,19 @@ func TestSelfConnect(t *testing.T) { } if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + return nil, err } - if err := s.AddAddress(1, ipv4.ProtocolNumber, context.StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + for _, ct := range []struct { + number tcpip.NetworkProtocolNumber + address tcpip.Address + }{ + {ipv4.ProtocolNumber, context.StackAddr}, + {ipv6.ProtocolNumber, context.StackV6Addr}, + } { + if err := s.AddAddress(1, ct.number, ct.address); err != nil { + return nil, err + } } s.SetRouteTable([]tcpip.Route{ @@ -3040,8 +3049,27 @@ func TestSelfConnect(t *testing.T) { Gateway: "", NIC: 1, }, + { + Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, }) + return s, nil +} + +func TestSelfConnect(t *testing.T) { + // This test ensures that intentional self-connects work. In particular, + // it checks that if an endpoint binds to say 127.0.0.1:1000 then + // connects to 127.0.0.1:1000, then it will be connected to itself, and + // is able to send and receive data through the same endpoint. + s, err := makeStack() + if err != nil { + t.Fatal(err) + } + var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { @@ -3095,6 +3123,154 @@ func TestSelfConnect(t *testing.T) { } } +func TestConnectAvoidsBoundPorts(t *testing.T) { + addressTypes := func(t *testing.T, network string) []string { + switch network { + case "ipv4": + return []string{"v4"} + case "ipv6": + return []string{"v6"} + case "dual": + return []string{"v6", "mapped"} + default: + t.Fatalf("unknown network: '%s'", network) + } + + panic("unreachable") + } + + address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { + switch addressType { + case "v4": + if isAny { + return "" + } + return context.StackAddr + case "v6": + if isAny { + return "" + } + return context.StackV6Addr + case "mapped": + if isAny { + return context.V4MappedWildcardAddr + } + return context.StackV4MappedAddr + default: + t.Fatalf("unknown address type: '%s'", addressType) + } + + panic("unreachable") + } + // This test ensures that Endpoint.Connect doesn't select already-bound ports. + networks := []string{"ipv4", "ipv6", "dual"} + for _, exhaustedNetwork := range networks { + t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { + for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { + t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { + for _, isAny := range []bool{false, true} { + t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { + for _, candidateNetwork := range networks { + t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { + for _, candidateAddressType := range addressTypes(t, candidateNetwork) { + t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { + s, err := makeStack() + if err != nil { + t.Fatal(err) + } + + var wq waiter.Queue + var eps []tcpip.Endpoint + defer func() { + for _, ep := range eps { + ep.Close() + } + }() + makeEP := func(network string) tcpip.Endpoint { + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch network { + case "ipv4": + networkProtocolNumber = ipv4.ProtocolNumber + case "ipv6", "dual": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatalf("unknown network: '%s'", network) + } + ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + eps = append(eps, ep) + switch network { + case "ipv4": + case "ipv6": + if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err) + } + case "dual": + if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err) + } + default: + t.Fatalf("unknown network: '%s'", network) + } + return ep + } + + var v4reserved, v6reserved bool + switch exhaustedAddressType { + case "v4", "mapped": + v4reserved = true + case "v6": + v6reserved = true + // Dual stack sockets bound to v6 any reserve on v4 as + // well. + if isAny { + switch exhaustedNetwork { + case "ipv6": + case "dual": + v4reserved = true + default: + t.Fatalf("unknown address type: '%s'", exhaustedNetwork) + } + } + default: + t.Fatalf("unknown address type: '%s'", exhaustedAddressType) + } + var collides bool + switch candidateAddressType { + case "v4", "mapped": + collides = v4reserved + case "v6": + collides = v6reserved + default: + t.Fatalf("unknown address type: '%s'", candidateAddressType) + } + + for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { + if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}, nil); err != nil { + t.Fatalf("Bind(%d) failed: %v", i, err) + } + } + want := tcpip.ErrConnectStarted + if collides { + want = tcpip.ErrNoPortAvailable + } + if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { + t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want) + } + }) + } + }) + } + }) + } + }) + } + }) + } +} + func TestPathMTUDiscovery(t *testing.T) { // This test verifies the stack retransmits packets after it receives an // ICMP packet indicating that the path MTU has been exceeded. -- cgit v1.2.3