summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@google.com>2018-09-17 20:42:48 -0700
committerShentubot <shentubot@google.com>2018-09-17 20:44:04 -0700
commitd6409b6564d6f908a217709010df2276497b264b (patch)
treeeef606f06646507d0fbfdba60d48ffc3ff46b715 /pkg
parentbb88c187c5457df14fa78e5e6b6f48cbc90fb489 (diff)
Prevent TCP connect from picking bound ports
PiperOrigin-RevId: 213387851 Change-Id: Icc6850761bc11afd0525f34863acd77584155140
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/ports/ports.go46
-rw-r--r--pkg/tcpip/ports/ports_test.go10
-rw-r--r--pkg/tcpip/transport/tcp/BUILD2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go20
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go194
5 files changed, 233 insertions, 39 deletions
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index db7371efb..4e24efddb 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -24,8 +24,8 @@ import (
)
const (
- // firstEphemeral is the first ephemeral port.
- firstEphemeral uint16 = 16000
+ // FirstEphemeral is the first ephemeral port.
+ FirstEphemeral = 16000
anyIPAddress tcpip.Address = ""
)
@@ -73,11 +73,11 @@ func NewPortManager() *PortManager {
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
- count := uint16(math.MaxUint16 - firstEphemeral + 1)
+ count := uint16(math.MaxUint16 - FirstEphemeral + 1)
offset := uint16(rand.Int31n(int32(count)))
for i := uint16(0); i < count; i++ {
- port = firstEphemeral + (offset+i)%count
+ port = FirstEphemeral + (offset+i)%count
ok, err := testPort(port)
if err != nil {
return 0, err
@@ -91,6 +91,25 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er
return 0, tcpip.ErrNoPortAvailable
}
+// IsPortAvailable tests if the given port is available on all given protocols.
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.isPortAvailableLocked(networks, transport, addr, port)
+}
+
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ if addrs, ok := s.allocatedPorts[desc]; ok {
+ if !addrs.isAvailable(addr) {
+ return false
+ }
+ }
+ }
+ return true
+}
+
// ReservePort marks a port/IP combination as reserved so that it cannot be
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
@@ -116,14 +135,8 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// reserveSpecificPort tries to reserve the given port on all given protocols.
func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
- // Check that the port is available on all network protocols.
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr) {
- return false
- }
- }
+ if !s.isPortAvailableLocked(networks, transport, addr, port) {
+ return false
}
// Reserve port on all network protocols.
@@ -148,10 +161,11 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
for _, network := range networks {
desc := portDescriptor{network, transport, port}
- m := s.allocatedPorts[desc]
- delete(m, addr)
- if len(m) == 0 {
- delete(s.allocatedPorts, desc)
+ if m, ok := s.allocatedPorts[desc]; ok {
+ delete(m, addr)
+ if len(m) == 0 {
+ delete(s.allocatedPorts, desc)
+ }
}
}
}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 825d5d314..4ab6a1fa2 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -78,8 +78,8 @@ func TestPortReservation(t *testing.T) {
if err != test.want {
t.Fatalf("ReservePort(.., .., %s, %d) = %v, want %v", test.ip, test.port, err, test.want)
}
- if test.port == 0 && (gotPort == 0 || gotPort < firstEphemeral) {
- t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, firstEphemeral)
+ if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
+ t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
}
}
@@ -118,17 +118,17 @@ func TestPickEphemeralPort(t *testing.T) {
{
name: "only-port-16042-available",
f: func(port uint16) (bool, *tcpip.Error) {
- if port == firstEphemeral+42 {
+ if port == FirstEphemeral+42 {
return true, nil
}
return false, nil
},
- wantPort: firstEphemeral + 42,
+ wantPort: FirstEphemeral + 42,
},
{
name: "only-port-under-16000-available",
f: func(port uint16) (bool, *tcpip.Error) {
- if port < firstEphemeral {
+ if port < FirstEphemeral {
return true, nil
}
return false, nil
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index c7943f08e..5a77ee232 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -71,6 +71,8 @@ go_test(
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp/testing/context",
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 4085585b0..a048cadf8 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1000,23 +1000,26 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// address/port for both local and remote (otherwise this
// endpoint would be trying to connect to itself).
sameAddr := e.id.LocalAddress == e.id.RemoteAddress
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.id.RemotePort {
return false, nil
}
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) {
+ return false, nil
+ }
- e.id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
- switch err {
+ id := e.id
+ id.LocalPort = p
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) {
case nil:
+ e.id = id
return true, nil
case tcpip.ErrPortInUse:
return false, nil
default:
return false, err
}
- })
- if err != nil {
+ }); err != nil {
return err
}
}
@@ -1217,7 +1220,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
// Bind binds the endpoint to a specific local port and optionally address.
-func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (retErr *tcpip.Error) {
+func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
@@ -1245,7 +1248,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (ret
}
}
- // Reserve the port.
port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port)
if err != nil {
return err
@@ -1257,7 +1259,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (ret
// Any failures beyond this point must remove the port registration.
defer func() {
- if retErr != nil {
+ if err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
e.isPortReserved = false
e.effectiveNetProtos = nil
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 871177842..ac21e565b 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -28,6 +28,8 @@ import (
"gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback"
"gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/ports"
"gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
@@ -3013,12 +3015,11 @@ func TestMinMaxBufferSizes(t *testing.T) {
checkSendBufferSize(t, ep, tcp.DefaultBufferSize*30)
}
-func TestSelfConnect(t *testing.T) {
- // This test ensures that intentional self-connects work. In particular,
- // it checks that if an endpoint binds to say 127.0.0.1:1000 then
- // connects to 127.0.0.1:1000, then it will be connected to itself, and
- // is able to send and receive data through the same endpoint.
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+func makeStack() (*stack.Stack, *tcpip.Error) {
+ s := stack.New([]string{
+ ipv4.ProtocolName,
+ ipv6.ProtocolName,
+ }, []string{tcp.ProtocolName}, stack.Options{})
id := loopback.New()
if testing.Verbose() {
@@ -3026,11 +3027,19 @@ func TestSelfConnect(t *testing.T) {
}
if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ return nil, err
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, context.StackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ for _, ct := range []struct {
+ number tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ }{
+ {ipv4.ProtocolNumber, context.StackAddr},
+ {ipv6.ProtocolNumber, context.StackV6Addr},
+ } {
+ if err := s.AddAddress(1, ct.number, ct.address); err != nil {
+ return nil, err
+ }
}
s.SetRouteTable([]tcpip.Route{
@@ -3040,8 +3049,27 @@ func TestSelfConnect(t *testing.T) {
Gateway: "",
NIC: 1,
},
+ {
+ Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ },
})
+ return s, nil
+}
+
+func TestSelfConnect(t *testing.T) {
+ // This test ensures that intentional self-connects work. In particular,
+ // it checks that if an endpoint binds to say 127.0.0.1:1000 then
+ // connects to 127.0.0.1:1000, then it will be connected to itself, and
+ // is able to send and receive data through the same endpoint.
+ s, err := makeStack()
+ if err != nil {
+ t.Fatal(err)
+ }
+
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
@@ -3095,6 +3123,154 @@ func TestSelfConnect(t *testing.T) {
}
}
+func TestConnectAvoidsBoundPorts(t *testing.T) {
+ addressTypes := func(t *testing.T, network string) []string {
+ switch network {
+ case "ipv4":
+ return []string{"v4"}
+ case "ipv6":
+ return []string{"v6"}
+ case "dual":
+ return []string{"v6", "mapped"}
+ default:
+ t.Fatalf("unknown network: '%s'", network)
+ }
+
+ panic("unreachable")
+ }
+
+ address := func(t *testing.T, addressType string, isAny bool) tcpip.Address {
+ switch addressType {
+ case "v4":
+ if isAny {
+ return ""
+ }
+ return context.StackAddr
+ case "v6":
+ if isAny {
+ return ""
+ }
+ return context.StackV6Addr
+ case "mapped":
+ if isAny {
+ return context.V4MappedWildcardAddr
+ }
+ return context.StackV4MappedAddr
+ default:
+ t.Fatalf("unknown address type: '%s'", addressType)
+ }
+
+ panic("unreachable")
+ }
+ // This test ensures that Endpoint.Connect doesn't select already-bound ports.
+ networks := []string{"ipv4", "ipv6", "dual"}
+ for _, exhaustedNetwork := range networks {
+ t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) {
+ for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) {
+ t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) {
+ for _, isAny := range []bool{false, true} {
+ t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) {
+ for _, candidateNetwork := range networks {
+ t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) {
+ for _, candidateAddressType := range addressTypes(t, candidateNetwork) {
+ t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) {
+ s, err := makeStack()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var wq waiter.Queue
+ var eps []tcpip.Endpoint
+ defer func() {
+ for _, ep := range eps {
+ ep.Close()
+ }
+ }()
+ makeEP := func(network string) tcpip.Endpoint {
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch network {
+ case "ipv4":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "ipv6", "dual":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatalf("unknown network: '%s'", network)
+ }
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ eps = append(eps, ep)
+ switch network {
+ case "ipv4":
+ case "ipv6":
+ if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err)
+ }
+ case "dual":
+ if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
+ t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err)
+ }
+ default:
+ t.Fatalf("unknown network: '%s'", network)
+ }
+ return ep
+ }
+
+ var v4reserved, v6reserved bool
+ switch exhaustedAddressType {
+ case "v4", "mapped":
+ v4reserved = true
+ case "v6":
+ v6reserved = true
+ // Dual stack sockets bound to v6 any reserve on v4 as
+ // well.
+ if isAny {
+ switch exhaustedNetwork {
+ case "ipv6":
+ case "dual":
+ v4reserved = true
+ default:
+ t.Fatalf("unknown address type: '%s'", exhaustedNetwork)
+ }
+ }
+ default:
+ t.Fatalf("unknown address type: '%s'", exhaustedAddressType)
+ }
+ var collides bool
+ switch candidateAddressType {
+ case "v4", "mapped":
+ collides = v4reserved
+ case "v6":
+ collides = v6reserved
+ default:
+ t.Fatalf("unknown address type: '%s'", candidateAddressType)
+ }
+
+ for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
+ if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}, nil); err != nil {
+ t.Fatalf("Bind(%d) failed: %v", i, err)
+ }
+ }
+ want := tcpip.ErrConnectStarted
+ if collides {
+ want = tcpip.ErrNoPortAvailable
+ }
+ if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
+ t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
func TestPathMTUDiscovery(t *testing.T) {
// This test verifies the stack retransmits packets after it receives an
// ICMP packet indicating that the path MTU has been exceeded.