diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/ports/ports.go | 49 | ||||
-rw-r--r-- | pkg/tcpip/ports/ports_test.go | 56 | ||||
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 17 |
6 files changed, 143 insertions, 6 deletions
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 40e202717..30cea8996 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -19,6 +19,7 @@ import ( "math" "math/rand" "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -27,6 +28,10 @@ const ( // FirstEphemeral is the first ephemeral port. FirstEphemeral = 16000 + // numEphemeralPorts it the mnumber of available ephemeral ports to + // Netstack. + numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1 + anyIPAddress tcpip.Address = "" ) @@ -40,6 +45,13 @@ type portDescriptor struct { type PortManager struct { mu sync.RWMutex allocatedPorts map[portDescriptor]bindAddresses + + // hint is used to pick ports ephemeral ports in a stable order for + // a given port offset. + // + // hint must be accessed using the portHint/incPortHint helpers. + // TODO(gvisor.dev/issue/940): S/R this field. + hint uint32 } type portNode struct { @@ -130,11 +142,40 @@ func NewPortManager() *PortManager { // is suitable for its needs, and stopping when a port is found or an error // occurs. func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { - count := uint16(math.MaxUint16 - FirstEphemeral + 1) - offset := uint16(rand.Int31n(int32(count))) + offset := uint32(rand.Int31n(numEphemeralPorts)) + return s.pickEphemeralPort(offset, numEphemeralPorts, testPort) +} + +// portHint atomically reads and returns the s.hint value. +func (s *PortManager) portHint() uint32 { + return atomic.LoadUint32(&s.hint) +} + +// incPortHint atomically increments s.hint by 1. +func (s *PortManager) incPortHint() { + atomic.AddUint32(&s.hint, 1) +} + +// PickEphemeralPortStable starts at the specified offset + s.portHint and +// iterates over all ephemeral ports, allowing the caller to decide whether a +// given port is suitable for its needs and stopping when a port is found or an +// error occurs. +func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { + p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort) + if err == nil { + s.incPortHint() + } + return p, err + +} - for i := uint16(0); i < count; i++ { - port = FirstEphemeral + (offset+i)%count +// pickEphemeralPort starts at the offset specified from the FirstEphemeral port +// and iterates over the number of ports specified by count and allows the +// caller to decide whether a given port is suitable for its needs, and stopping +// when a port is found or an error occurs. +func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { + for i := uint32(0); i < count; i++ { + port = uint16(FirstEphemeral + (offset+i)%count) ok, err := testPort(port) if err != nil { return 0, err diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index a67e283f1..19f4833fc 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -15,6 +15,7 @@ package ports import ( + "math/rand" "testing" "gvisor.dev/gvisor/pkg/tcpip" @@ -232,7 +233,6 @@ func TestPortReservation(t *testing.T) { } func TestPickEphemeralPort(t *testing.T) { - pm := NewPortManager() customErr := &tcpip.Error{} for _, test := range []struct { name string @@ -276,9 +276,63 @@ func TestPickEphemeralPort(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { + pm := NewPortManager() if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) } }) } } + +func TestPickEphemeralPortStable(t *testing.T) { + customErr := &tcpip.Error{} + for _, test := range []struct { + name string + f func(port uint16) (bool, *tcpip.Error) + wantErr *tcpip.Error + wantPort uint16 + }{ + { + name: "no-port-available", + f: func(port uint16) (bool, *tcpip.Error) { + return false, nil + }, + wantErr: tcpip.ErrNoPortAvailable, + }, + { + name: "port-tester-error", + f: func(port uint16) (bool, *tcpip.Error) { + return false, customErr + }, + wantErr: customErr, + }, + { + name: "only-port-16042-available", + f: func(port uint16) (bool, *tcpip.Error) { + if port == FirstEphemeral+42 { + return true, nil + } + return false, nil + }, + wantPort: FirstEphemeral + 42, + }, + { + name: "only-port-under-16000-available", + f: func(port uint16) (bool, *tcpip.Error) { + if port < FirstEphemeral { + return true, nil + } + return false, nil + }, + wantErr: tcpip.ErrNoPortAvailable, + }, + } { + t.Run(test.name, func(t *testing.T) { + pm := NewPortManager() + portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) + if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr { + t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + } + }) + } +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 3842f1f7d..baf88bfab 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -36,6 +36,7 @@ go_library( ], deps = [ "//pkg/ilist", + "//pkg/rand", "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6a8079823..90c2cf1be 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,10 +20,12 @@ package stack import ( + "encoding/binary" "sync" "time" "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -388,6 +390,12 @@ type Stack struct { // icmpRateLimiter is a global rate limiter for all ICMP messages generated // by the stack. icmpRateLimiter *ICMPRateLimiter + + // portSeed is a one-time random value initialized at stack startup + // and is used to seed the TCP port picking on active connections + // + // TODO(gvisor.dev/issues/940): S/R this field. + portSeed uint32 } // Options contains optional Stack configuration. @@ -440,6 +448,7 @@ func New(opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, icmpRateLimiter: NewICMPRateLimiter(), + portSeed: generateRandUint32(), } // Add specified network protocols. @@ -1197,3 +1206,19 @@ func (s *Stack) SetICMPBurst(burst int) { func (s *Stack) AllowICMPMessage() bool { return s.icmpRateLimiter.Allow() } + +// PortSeed returns a 32 bit value that can be used as a seed value for port +// picking. +// +// NOTE: The seed is generated once during stack initialization only. +func (s *Stack) PortSeed() uint32 { + return s.portSeed +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return binary.LittleEndian.Uint32(b) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 39a839ab7..a42e1f4a2 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -49,6 +49,7 @@ go_library( "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/seqnum", diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a1cd0d481..f9d5e0085 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "fmt" "math" "strings" @@ -26,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -1504,7 +1506,20 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). sameAddr := e.id.LocalAddress == e.id.RemoteAddress - if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + + // Calculate a port offset based on the destination IP/port and + // src IP to ensure that for a given tuple (srcIP, destIP, + // destPort) the offset used as a starting point is the same to + // ensure that we can cycle through the port space effectively. + h := jenkins.Sum32(e.stack.PortSeed()) + h.Write([]byte(e.id.LocalAddress)) + h.Write([]byte(e.id.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, e.id.RemotePort) + h.Write(portBuf) + portOffset := h.Sum32() + + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.id.RemotePort { return false, nil } |