summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorKevin Krakauer <krakauer@google.com>2021-04-16 16:26:31 -0700
committergVisor bot <gvisor-bot@google.com>2021-04-16 16:28:56 -0700
commit32c18f443f567dac21465b3999d1a18b886891d1 (patch)
tree6be713fc4cb0481d526d212c084b328e0bc671ea /pkg
parent6241f89f49820c193213b2d395bb09030409166d (diff)
Enlarge port range and fix integer overflow
Also count failed TCP port allocations PiperOrigin-RevId: 368939619
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/netstack/netstack.go1
-rw-r--r--pkg/tcpip/ports/ports.go26
-rw-r--r--pkg/tcpip/ports/ports_test.go29
-rw-r--r--pkg/tcpip/tcpip.go4
-rw-r--r--pkg/tcpip/transport/tcp/accept.go1
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go4
6 files changed, 51 insertions, 14 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index ed6572bab..b5ca3a693 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -242,6 +242,7 @@ var Metrics = tcpip.Stats{
FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."),
Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."),
ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
+ FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."),
},
UDP: tcpip.UDPStats{
PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 678199371..b5b013b64 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -17,6 +17,7 @@
package ports
import (
+ "math"
"math/rand"
"sync/atomic"
@@ -24,7 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-const anyIPAddress tcpip.Address = ""
+const (
+ firstEphemeral = 16000
+ anyIPAddress tcpip.Address = ""
+)
// Reservation describes a port reservation.
type Reservation struct {
@@ -220,10 +224,8 @@ type PortManager struct {
func NewPortManager() *PortManager {
return &PortManager{
allocatedPorts: make(map[portDescriptor]addrToDevice),
- // Match Linux's default ephemeral range. See:
- // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842
- firstEphemeral: 32768,
- numEphemeral: 28232,
+ firstEphemeral: firstEphemeral,
+ numEphemeral: math.MaxUint16 - firstEphemeral + 1,
}
}
@@ -242,13 +244,13 @@ func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err
numEphemeral := pm.numEphemeral
pm.ephemeralMu.RUnlock()
- offset := uint16(rand.Int31n(int32(numEphemeral)))
+ offset := uint32(rand.Int31n(int32(numEphemeral)))
return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
}
// portHint atomically reads and returns the pm.hint value.
-func (pm *PortManager) portHint() uint16 {
- return uint16(atomic.LoadUint32(&pm.hint))
+func (pm *PortManager) portHint() uint32 {
+ return atomic.LoadUint32(&pm.hint)
}
// incPortHint atomically increments pm.hint by 1.
@@ -260,7 +262,7 @@ func (pm *PortManager) incPortHint() {
// iterates over all ephemeral ports, allowing the caller to decide whether a
// given port is suitable for its needs and stopping when a port is found or an
// error occurs.
-func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+func (pm *PortManager) PickEphemeralPortStable(offset uint32, testPort PortTester) (port uint16, err tcpip.Error) {
pm.ephemeralMu.RLock()
firstEphemeral := pm.firstEphemeral
numEphemeral := pm.numEphemeral
@@ -277,9 +279,9 @@ func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTeste
// and iterates over the number of ports specified by count and allows the
// caller to decide whether a given port is suitable for its needs, and stopping
// when a port is found or an error occurs.
-func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
- for i := uint16(0); i < count; i++ {
- port = first + (offset+i)%count
+func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+ for i := uint32(0); i < uint32(count); i++ {
+ port := uint16(uint32(first) + (offset+i)%uint32(count))
ok, err := testPort(port)
if err != nil {
return 0, err
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 5dfc5371a..6c4fb8c68 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -15,6 +15,7 @@
package ports
import (
+ "math"
"math/rand"
"testing"
@@ -482,7 +483,7 @@ func TestPickEphemeralPortStable(t *testing.T) {
if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
t.Fatalf("failed to set ephemeral port range: %s", err)
}
- portOffset := uint16(rand.Int31n(int32(numEphemeralPorts)))
+ portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
port, err := pm.PickEphemeralPortStable(portOffset, test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
@@ -493,3 +494,29 @@ func TestPickEphemeralPortStable(t *testing.T) {
})
}
}
+
+// TestOverflow addresses b/183593432, wherein an overflowing uint16 causes a
+// port allocation failure.
+func TestOverflow(t *testing.T) {
+ // Use a small range and start at offsets that will cause an overflow.
+ count := uint16(50)
+ for offset := uint32(math.MaxUint16 - count); offset < math.MaxUint16; offset++ {
+ reservedPorts := make(map[uint16]struct{})
+ // Ensure we can reserve everything in the allowed range.
+ for i := uint16(0); i < count; i++ {
+ port, err := pickEphemeralPort(offset, firstEphemeral, count, func(port uint16) (bool, tcpip.Error) {
+ if _, ok := reservedPorts[port]; !ok {
+ reservedPorts[port] = struct{}{}
+ return true, nil
+ }
+ return false, nil
+ })
+ if err != nil {
+ t.Fatalf("port picking failed at iteration %d, for offset %d, len(reserved): %+v", i, offset, len(reservedPorts))
+ }
+ if port < firstEphemeral || port > firstEphemeral+count {
+ t.Fatalf("reserved port %d, which is not in range [%d, %d]", port, firstEphemeral, firstEphemeral+count-1)
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index d1b5982f5..33a55d89f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1732,6 +1732,10 @@ type TCPStats struct {
// ChecksumErrors is the number of segments dropped due to bad checksums.
ChecksumErrors *StatCounter
+
+ // FailedPortReservations is the number of times TCP failed to reserve
+ // a port.
+ FailedPortReservations *StatCounter
}
// UDPStats collects UDP-specific stats.
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 63c46b1be..4be306434 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -455,6 +455,7 @@ func (e *endpoint) reserveTupleLocked() bool {
Dest: dest,
}
if !e.stack.ReserveTuple(portRes) {
+ e.stack.Stats().TCP.FailedPortReservations.Increment()
return false
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 9afd2bb7f..bc88e48e9 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2251,7 +2251,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
panic(err)
}
}
- portOffset := uint16(h.Sum32())
+ portOffset := h.Sum32()
var twReuse tcpip.TCPTimeWaitReuseOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
@@ -2362,6 +2362,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
e.boundDest = addr
return true, nil
}); err != nil {
+ e.stack.Stats().TCP.FailedPortReservations.Increment()
return err
}
}
@@ -2685,6 +2686,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
return true, nil
})
if err != nil {
+ e.stack.Stats().TCP.FailedPortReservations.Increment()
return err
}