diff options
-rw-r--r-- | pkg/sentry/fs/proc/sys_net.go | 121 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/proc/tasks_sys.go | 78 | ||||
-rw-r--r-- | pkg/sentry/inet/inet.go | 8 | ||||
-rw-r--r-- | pkg/sentry/inet/test_stack.go | 12 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/stack.go | 11 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/stack.go | 10 | ||||
-rw-r--r-- | pkg/syserr/netstack.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/errors.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/ports/ports.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/ports/ports_test.go | 35 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 4 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 6 | ||||
-rw-r--r-- | test/syscalls/linux/proc_net.cc | 37 | ||||
-rw-r--r-- | test/syscalls/linux/socket_generic_stress.cc | 136 |
17 files changed, 497 insertions, 68 deletions
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 52061175f..bbe282c03 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -17,6 +17,7 @@ package proc import ( "fmt" "io" + "math" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -26,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -498,6 +500,120 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled) } +// portRangeInode implements fs.InodeOperations. It provides and allows +// modification of the range of ephemeral ports that IPv4 and IPv6 sockets +// choose from. +// +// +stateify savable +type portRangeInode struct { + fsutil.SimpleFileInode + + stack inet.Stack `state:"wait"` + + // start and end store the port range. We must save/restore this here, + // since a netstack instance is created on restore. + start *uint16 + end *uint16 +} + +func newPortRangeInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { + ipf := &portRangeInode{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), + stack: s, + } + sattr := fs.StableAttr{ + DeviceID: device.ProcDevice.DeviceID(), + InodeID: device.ProcDevice.NextIno(), + BlockSize: usermem.PageSize, + Type: fs.SpecialFile, + } + return fs.NewInode(ctx, ipf, msrc, sattr) +} + +// Truncate implements fs.InodeOperations.Truncate. Truncate is called when +// O_TRUNC is specified for any kind of existing Dirent but is not called via +// (f)truncate for proc files. +func (*portRangeInode) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// +stateify savable +type portRangeFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + inode *portRangeInode +} + +// GetFile implements fs.InodeOperations.GetFile. +func (in *portRangeInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + flags.Pread = true + flags.Pwrite = true + return fs.NewFile(ctx, dirent, flags, &portRangeFile{ + inode: in, + }), nil +} + +// Read implements fs.FileOperations.Read. +func (pf *portRangeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + + if pf.inode.start == nil { + start, end := pf.inode.stack.PortRange() + pf.inode.start = &start + pf.inode.end = &end + } + + contents := fmt.Sprintf("%d %d\n", *pf.inode.start, *pf.inode.end) + n, err := dst.CopyOut(ctx, []byte(contents)) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +// +// Offset is ignored, multiple writes are not supported. +func (pf *portRangeFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Only consider size of one memory page for input for performance + // reasons. + src = src.TakeFirst(usermem.PageSize - 1) + + ports := make([]int32, 2) + n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts) + if err != nil { + return 0, err + } + + // Port numbers must be uint16s. + if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 { + return 0, syserror.EINVAL + } + + if err := pf.inode.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil { + return 0, err + } + if pf.inode.start == nil { + pf.inode.start = new(uint16) + pf.inode.end = new(uint16) + } + *pf.inode.start = uint16(ports[0]) + *pf.inode.end = uint16(ports[1]) + return n, nil +} + func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { contents := map[string]*fs.Inode{ // Add tcp_sack. @@ -506,12 +622,15 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine // Add ip_forward. "ip_forward": newIPForwardingInode(ctx, msrc, s), + // Allow for configurable ephemeral port ranges. Note that this + // controls ports for both IPv4 and IPv6 sockets. + "ip_local_port_range": newPortRangeInode(ctx, msrc, s), + // The following files are simple stubs until they are // implemented in netstack, most of these files are // configuration related. We use the value closest to the // actual netstack behavior or any empty file, all of these // files will have mode 0444 (read-only for all users). - "ip_local_port_range": newStaticProcInode(ctx, msrc, []byte("16000 65535")), "ip_local_reserved_ports": newStaticProcInode(ctx, msrc, []byte("")), "ipfrag_time": newStaticProcInode(ctx, msrc, []byte("30")), "ip_nonlocal_bind": newStaticProcInode(ctx, msrc, []byte("0")), diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index fd7823daa..fb274b78e 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -17,6 +17,7 @@ package proc import ( "bytes" "fmt" + "math" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -69,17 +70,17 @@ func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]kernfs.Inode{ "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ - "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), - "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), - "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), - "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), - "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), + "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), + "ip_local_port_range": fs.newInode(ctx, root, 0644, &portRange{stack: stack}), + "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), + "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), + "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), + "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the // value closest to the actual netstack behavior or any empty file, all // of these files will have mode 0444 (read-only for all users). - "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")), "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")), "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")), "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")), @@ -421,3 +422,68 @@ func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offs } return n, nil } + +// portRange implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/ip_local_port_range. +// +// +stateify savable +type portRange struct { + kernfs.DynamicBytesFile + + stack inet.Stack `state:"wait"` + + // start and end store the port range. We must save/restore this here, + // since a netstack instance is created on restore. + start *uint16 + end *uint16 +} + +var _ vfs.WritableDynamicBytesSource = (*portRange)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (pr *portRange) Generate(ctx context.Context, buf *bytes.Buffer) error { + if pr.start == nil { + start, end := pr.stack.PortRange() + pr.start = &start + pr.end = &end + } + _, err := fmt.Fprintf(buf, "%d %d\n", *pr.start, *pr.end) + return err +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit input size so as not to impact performance if input size is + // large. + src = src.TakeFirst(usermem.PageSize - 1) + + ports := make([]int32, 2) + n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts) + if err != nil { + return 0, err + } + + // Port numbers must be uint16s. + if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 { + return 0, syserror.EINVAL + } + + if err := pr.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil { + return 0, err + } + if pr.start == nil { + pr.start = new(uint16) + pr.end = new(uint16) + } + *pr.start = uint16(ports[0]) + *pr.end = uint16(ports[1]) + return n, nil +} diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index f31277d30..6b71bd3a9 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -93,6 +93,14 @@ type Stack interface { // SetForwarding enables or disables packet forwarding between NICs. SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error + + // PortRange returns the UDP and TCP inclusive range of ephemeral ports + // used in both IPv4 and IPv6. + PortRange() (uint16, uint16) + + // SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range + // (inclusive). + SetPortRange(start uint16, end uint16) error } // Interface contains information about a network interface. diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 9ebeba8a3..03e2608c2 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -164,3 +164,15 @@ func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable b s.IPForwarding = enable return nil } + +// PortRange implements inet.Stack.PortRange. +func (*TestStack) PortRange() (uint16, uint16) { + // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net(). + return 32768, 28232 +} + +// SetPortRange implements inet.Stack.SetPortRange. +func (*TestStack) SetPortRange(start uint16, end uint16) error { + // No-op. + return nil +} diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index e6323244c..5bcf92e14 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -504,3 +504,14 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES } + +// PortRange implements inet.Stack.PortRange. +func (*Stack) PortRange() (uint16, uint16) { + // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net(). + return 32768, 28232 +} + +// SetPortRange implements inet.Stack.SetPortRange. +func (*Stack) SetPortRange(start uint16, end uint16) error { + return syserror.EACCES +} diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 71c3bc034..b215067cf 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -478,3 +478,13 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) } return nil } + +// PortRange implements inet.Stack.PortRange. +func (s *Stack) PortRange() (uint16, uint16) { + return s.Stack.PortRange() +} + +// SetPortRange implements inet.Stack.SetPortRange. +func (s *Stack) SetPortRange(start uint16, end uint16) error { + return syserr.TranslateNetstackError(s.Stack.SetPortRange(start, end)).ToError() +} diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 0b9139570..79e564de6 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -51,6 +51,7 @@ var ( ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM) ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT) ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), linux.EINVAL) + ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), linux.EINVAL) ) // TranslateNetstackError converts an error from the tcpip package to a sentry @@ -135,6 +136,8 @@ func TranslateNetstackError(err tcpip.Error) *Error { return ErrBadBuffer case *tcpip.ErrMalformedHeader: return ErrMalformedHeader + case *tcpip.ErrInvalidPortRange: + return ErrInvalidPortRange default: panic(fmt.Sprintf("unknown error %T", err)) } diff --git a/pkg/tcpip/errors.go b/pkg/tcpip/errors.go index 3b7cc52f3..5d478ac32 100644 --- a/pkg/tcpip/errors.go +++ b/pkg/tcpip/errors.go @@ -300,6 +300,19 @@ func (*ErrInvalidOptionValue) IgnoreStats() bool { } func (*ErrInvalidOptionValue) String() string { return "invalid option value specified" } +// ErrInvalidPortRange indicates an attempt to set an invalid port range. +// +// +stateify savable +type ErrInvalidPortRange struct{} + +func (*ErrInvalidPortRange) isError() {} + +// IgnoreStats implements Error. +func (*ErrInvalidPortRange) IgnoreStats() bool { + return true +} +func (*ErrInvalidPortRange) String() string { return "invalid port range" } + // ErrMalformedHeader indicates the operation encountered a malformed header. // // +stateify savable diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 11dbdbbcf..101872b47 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -16,7 +16,6 @@ package ports import ( - "math" "math/rand" "sync/atomic" @@ -24,16 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -const ( - // FirstEphemeral is the first ephemeral port. - FirstEphemeral = 16000 - - // numEphemeralPorts it the mnumber of available ephemeral ports to - // Netstack. - numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1 - - anyIPAddress tcpip.Address = "" -) +const anyIPAddress tcpip.Address = "" type portDescriptor struct { network tcpip.NetworkProtocolNumber @@ -83,9 +73,16 @@ func (f Flags) Effective() Flags { // PortManager manages allocating, reserving and releasing ports. type PortManager struct { + // mu protects allocatedPorts. + // LOCK ORDERING: mu > ephemeralMu. mu sync.RWMutex allocatedPorts map[portDescriptor]bindAddresses + // ephemeralMu protects firstEphemeral and numEphemeral. + ephemeralMu sync.RWMutex + firstEphemeral uint16 + numEphemeral uint16 + // hint is used to pick ports ephemeral ports in a stable order for // a given port offset. // @@ -322,7 +319,13 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice // NewPortManager creates new PortManager. func NewPortManager() *PortManager { - return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)} + return &PortManager{ + allocatedPorts: make(map[portDescriptor]bindAddresses), + // Match Linux's default ephemeral range. See: + // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842 + firstEphemeral: 32768, + numEphemeral: 28232, + } } // PickEphemeralPort randomly chooses a starting point and iterates over all @@ -330,13 +333,18 @@ func NewPortManager() *PortManager { // is suitable for its needs, and stopping when a port is found or an error // occurs. func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - offset := uint32(rand.Int31n(numEphemeralPorts)) - return s.pickEphemeralPort(offset, numEphemeralPorts, testPort) + s.ephemeralMu.RLock() + firstEphemeral := s.firstEphemeral + numEphemeral := s.numEphemeral + s.ephemeralMu.RUnlock() + + offset := uint16(rand.Int31n(int32(numEphemeral))) + return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } // portHint atomically reads and returns the s.hint value. -func (s *PortManager) portHint() uint32 { - return atomic.LoadUint32(&s.hint) +func (s *PortManager) portHint() uint16 { + return uint16(atomic.LoadUint32(&s.hint)) } // incPortHint atomically increments s.hint by 1. @@ -348,8 +356,13 @@ func (s *PortManager) incPortHint() { // iterates over all ephemeral ports, allowing the caller to decide whether a // given port is suitable for its needs and stopping when a port is found or an // error occurs. -func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort) +func (s *PortManager) PickEphemeralPortStable(offset uint16, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { + s.ephemeralMu.RLock() + firstEphemeral := s.firstEphemeral + numEphemeral := s.numEphemeral + s.ephemeralMu.RUnlock() + + p, err := pickEphemeralPort(s.portHint()+offset, firstEphemeral, numEphemeral, testPort) if err == nil { s.incPortHint() } @@ -361,9 +374,9 @@ func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uin // and iterates over the number of ports specified by count and allows the // caller to decide whether a given port is suitable for its needs, and stopping // when a port is found or an error occurs. -func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - for i := uint32(0); i < count; i++ { - port = uint16(FirstEphemeral + (offset+i)%count) +func pickEphemeralPort(offset, first, count uint16, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { + for i := uint16(0); i < count; i++ { + port = first + (offset+i)%count ok, err := testPort(port) if err != nil { return 0, err @@ -567,3 +580,24 @@ func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, } } } + +// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in +// both IPv4 and IPv6. +func (s *PortManager) PortRange() (uint16, uint16) { + s.ephemeralMu.RLock() + defer s.ephemeralMu.RUnlock() + return s.firstEphemeral, s.firstEphemeral + s.numEphemeral - 1 +} + +// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range +// (inclusive). +func (s *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error { + if start > end { + return &tcpip.ErrInvalidPortRange{} + } + s.ephemeralMu.Lock() + defer s.ephemeralMu.Unlock() + s.firstEphemeral = start + s.numEphemeral = end - start + 1 + return nil +} diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index e70fbb72b..6cfac04b1 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -329,6 +329,7 @@ func TestPortReservation(t *testing.T) { net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} for _, test := range test.actions { + first, _ := pm.PortRange() if test.release { pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) continue @@ -337,8 +338,8 @@ func TestPortReservation(t *testing.T) { if diff := cmp.Diff(test.want, err); diff != "" { t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff) } - if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) + if test.port == 0 && (gotPort == 0 || gotPort < first) { + t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, first) } } }) @@ -346,6 +347,11 @@ func TestPortReservation(t *testing.T) { } func TestPickEphemeralPort(t *testing.T) { + const ( + firstEphemeral = 32000 + numEphemeralPorts = 1000 + ) + for _, test := range []struct { name string f func(port uint16) (bool, tcpip.Error) @@ -369,17 +375,17 @@ func TestPickEphemeralPort(t *testing.T) { { name: "only-port-16042-available", f: func(port uint16) (bool, tcpip.Error) { - if port == FirstEphemeral+42 { + if port == firstEphemeral+42 { return true, nil } return false, nil }, - wantPort: FirstEphemeral + 42, + wantPort: firstEphemeral + 42, }, { name: "only-port-under-16000-available", f: func(port uint16) (bool, tcpip.Error) { - if port < FirstEphemeral { + if port < firstEphemeral { return true, nil } return false, nil @@ -389,6 +395,9 @@ func TestPickEphemeralPort(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() + if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { + t.Fatalf("failed to set ephemeral port range: %s", err) + } port, err := pm.PickEphemeralPort(test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) @@ -401,6 +410,11 @@ func TestPickEphemeralPort(t *testing.T) { } func TestPickEphemeralPortStable(t *testing.T) { + const ( + firstEphemeral = 32000 + numEphemeralPorts = 1000 + ) + for _, test := range []struct { name string f func(port uint16) (bool, tcpip.Error) @@ -424,17 +438,17 @@ func TestPickEphemeralPortStable(t *testing.T) { { name: "only-port-16042-available", f: func(port uint16) (bool, tcpip.Error) { - if port == FirstEphemeral+42 { + if port == firstEphemeral+42 { return true, nil } return false, nil }, - wantPort: FirstEphemeral + 42, + wantPort: firstEphemeral + 42, }, { name: "only-port-under-16000-available", f: func(port uint16) (bool, tcpip.Error) { - if port < FirstEphemeral { + if port < firstEphemeral { return true, nil } return false, nil @@ -444,7 +458,10 @@ func TestPickEphemeralPortStable(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() - portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) + if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { + t.Fatalf("failed to set ephemeral port range: %s", err) + } + portOffset := uint16(rand.Int31n(int32(numEphemeralPorts))) port, err := pm.PickEphemeralPortStable(portOffset, test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index de94ddfda..53370c354 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -813,6 +813,18 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { return forwardingProtocol.Forwarding() } +// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in +// both IPv4 and IPv6. +func (s *Stack) PortRange() (uint16, uint16) { + return s.PortManager.PortRange() +} + +// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range +// (inclusive). +func (s *Stack) SetPortRange(start uint16, end uint16) tcpip.Error { + return s.PortManager.SetPortRange(start, end) +} + // SetRouteTable assigns the route table to be used by this stack. It // specifies which NIC to use for given destination address ranges. // diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index fcdd032c5..a69d6624d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -105,7 +105,6 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp/testing/context", diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 687b9f459..4836f8adc 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2220,7 +2220,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) h.Write(portBuf) - portOffset := h.Sum32() + portOffset := uint16(h.Sum32()) var twReuse tcpip.TCPTimeWaitReuseOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 0128c1f7e..a684f204d 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -33,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -4783,7 +4782,8 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("unknown address type: '%s'", candidateAddressType) } - for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { + start, end := s.PortRange() + for i := start; i <= end; i++ { if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { t.Fatalf("Bind(%d) failed: %s", i, err) } diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 5371f825c..5399d8106 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2330,13 +2330,15 @@ cc_binary( ], linkstatic = 1, deps = [ + gtest, ":ip_socket_test_util", ":socket_test_util", - "@com_google_absl//absl/strings", - gtest, + "//test/util:file_descriptor", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", ], ) diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 73140b2e9..20f1dc305 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -40,6 +40,7 @@ namespace { constexpr const char kProcNet[] = "/proc/net"; constexpr const char kIpForward[] = "/proc/sys/net/ipv4/ip_forward"; +constexpr const char kRangeFile[] = "/proc/sys/net/ipv4/ip_local_port_range"; TEST(ProcNetSymlinkTarget, FileMode) { struct stat s; @@ -562,6 +563,42 @@ TEST(ProcSysNetIpv4IpForward, CanReadAndWrite) { EXPECT_EQ(buf, to_write); } +TEST(ProcSysNetPortRange, CanReadAndWrite) { + int min; + int max; + std::string rangefile = ASSERT_NO_ERRNO_AND_VALUE(GetContents(kRangeFile)); + ASSERT_EQ(rangefile.back(), '\n'); + rangefile.pop_back(); + std::vector<std::string> range = + absl::StrSplit(rangefile, absl::ByAnyChar("\t ")); + ASSERT_GT(range.size(), 1); + ASSERT_TRUE(absl::SimpleAtoi(range.front(), &min)); + ASSERT_TRUE(absl::SimpleAtoi(range.back(), &max)); + EXPECT_LE(min, max); + + // If the file isn't writable, there's nothing else to do here. + if (access(kRangeFile, W_OK)) { + return; + } + + constexpr int kSize = 77; + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(kRangeFile, O_WRONLY | O_TRUNC, 0)); + max = min + kSize; + const std::string small_range = absl::StrFormat("%d %d", min, max); + ASSERT_THAT(write(fd.get(), small_range.c_str(), small_range.size()), + SyscallSucceedsWithValue(small_range.size())); + + rangefile = ASSERT_NO_ERRNO_AND_VALUE(GetContents(kRangeFile)); + ASSERT_EQ(rangefile.back(), '\n'); + rangefile.pop_back(); + range = absl::StrSplit(rangefile, absl::ByAnyChar("\t ")); + ASSERT_GT(range.size(), 1); + ASSERT_TRUE(absl::SimpleAtoi(range.front(), &min)); + ASSERT_TRUE(absl::SimpleAtoi(range.back(), &max)); + EXPECT_EQ(min + kSize, max); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc index 679586530..c35aa2183 100644 --- a/test/syscalls/linux/socket_generic_stress.cc +++ b/test/syscalls/linux/socket_generic_stress.cc @@ -17,29 +17,72 @@ #include <sys/ioctl.h> #include <sys/socket.h> #include <sys/un.h> +#include <unistd.h> #include <array> #include <string> #include "gtest/gtest.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" namespace gvisor { namespace testing { +constexpr char kRangeFile[] = "/proc/sys/net/ipv4/ip_local_port_range"; + +PosixErrorOr<int> NumPorts() { + int min = 0; + int max = 1 << 16; + + // Read the ephemeral range from /proc. + ASSIGN_OR_RETURN_ERRNO(std::string rangefile, GetContents(kRangeFile)); + const std::string err_msg = + absl::StrFormat("%s has invalid content: %s", kRangeFile, rangefile); + if (rangefile.back() != '\n') { + return PosixError(EINVAL, err_msg); + } + rangefile.pop_back(); + std::vector<std::string> range = + absl::StrSplit(rangefile, absl::ByAnyChar("\t ")); + if (range.size() < 2 || !absl::SimpleAtoi(range.front(), &min) || + !absl::SimpleAtoi(range.back(), &max)) { + return PosixError(EINVAL, err_msg); + } + + // If we can open as writable, limit the range. + if (!access(kRangeFile, W_OK)) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, + Open(kRangeFile, O_WRONLY | O_TRUNC, 0)); + max = min + 50; + const std::string small_range = absl::StrFormat("%d %d", min, max); + int n = write(fd.get(), small_range.c_str(), small_range.size()); + if (n < 0) { + return PosixError( + errno, + absl::StrFormat("write(%d [%s], \"%s\", %d)", fd.get(), kRangeFile, + small_range.c_str(), small_range.size())); + } + } + return max - min; +} + // Test fixture for tests that apply to pairs of connected sockets. using ConnectStressTest = SocketPairTest; -TEST_P(ConnectStressTest, Reset65kTimes) { - // TODO(b/165912341): These are too slow on KVM platform with nested virt. - SKIP_IF(GvisorPlatform() == Platform::kKVM); - - for (int i = 0; i < 1 << 16; ++i) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(ConnectStressTest, Reset) { + const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + for (int i = 0; i < nports * 2; i++) { + const std::unique_ptr<SocketPair> sockets = + ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); // Send some data to ensure that the connection gets reset and the port gets // released immediately. This avoids either end entering TIME-WAIT. @@ -57,6 +100,24 @@ TEST_P(ConnectStressTest, Reset65kTimes) { } } +// Tests that opening too many connections -- without closing them -- does lead +// to port exhaustion. +TEST_P(ConnectStressTest, TooManyOpen) { + const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + int err_num = 0; + std::vector<std::unique_ptr<SocketPair>> sockets = + std::vector<std::unique_ptr<SocketPair>>(nports); + for (int i = 0; i < nports * 2; i++) { + PosixErrorOr<std::unique_ptr<SocketPair>> socks = NewSocketPair(); + if (!socks.ok()) { + err_num = socks.error().errno_value(); + break; + } + sockets.push_back(std::move(socks).ValueOrDie()); + } + ASSERT_EQ(err_num, EADDRINUSE); +} + INSTANTIATE_TEST_SUITE_P( AllConnectedSockets, ConnectStressTest, ::testing::Values(IPv6UDPBidirectionalBindSocketPair(0), @@ -73,14 +134,40 @@ INSTANTIATE_TEST_SUITE_P( // Test fixture for tests that apply to pairs of connected sockets created with // a persistent listener (if applicable). -using PersistentListenerConnectStressTest = SocketPairTest; +class PersistentListenerConnectStressTest : public SocketPairTest { + protected: + PersistentListenerConnectStressTest() : slept_{false} {} -TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) { - // TODO(b/165912341): These are too slow on KVM platform with nested virt. - SKIP_IF(GvisorPlatform() == Platform::kKVM); + // NewSocketSleep is the same as NewSocketPair, but will sleep once (over the + // lifetime of the fixture) and retry if creation fails due to EADDRNOTAVAIL. + PosixErrorOr<std::unique_ptr<SocketPair>> NewSocketSleep() { + // We can't reuse a connection too close in time to its last use, as TCP + // uses the timestamp difference to disambiguate connections. With a + // sufficiently small port range, we'll cycle through too quickly, and TCP + // won't allow for connection reuse. Thus, we sleep the first time + // encountering EADDRINUSE to allow for that difference (1 second in + // gVisor). + PosixErrorOr<std::unique_ptr<SocketPair>> socks = NewSocketPair(); + if (socks.ok()) { + return socks; + } + if (!slept_ && socks.error().errno_value() == EADDRNOTAVAIL) { + absl::SleepFor(absl::Milliseconds(1500)); + slept_ = true; + return NewSocketPair(); + } + return socks; + } - for (int i = 0; i < 1 << 16; ++i) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + private: + bool slept_; +}; + +TEST_P(PersistentListenerConnectStressTest, ShutdownCloseFirst) { + const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + for (int i = 0; i < nports * 2; i++) { + std::unique_ptr<SocketPair> sockets = + ASSERT_NO_ERRNO_AND_VALUE(NewSocketSleep()); ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); if (GetParam().type == SOCK_STREAM) { // Poll the other FD to make sure that we see the FIN from the other @@ -97,12 +184,11 @@ TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) { } } -TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) { - // TODO(b/165912341): These are too slow on KVM platform with nested virt. - SKIP_IF(GvisorPlatform() == Platform::kKVM); - - for (int i = 0; i < 1 << 16; ++i) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(PersistentListenerConnectStressTest, ShutdownCloseSecond) { + const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + for (int i = 0; i < nports * 2; i++) { + const std::unique_ptr<SocketPair> sockets = + ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); if (GetParam().type == SOCK_STREAM) { // Poll the other FD to make sure that we see the FIN from the other @@ -119,12 +205,11 @@ TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) { } } -TEST_P(PersistentListenerConnectStressTest, 65kTimesClose) { - // TODO(b/165912341): These are too slow on KVM platform with nested virt. - SKIP_IF(GvisorPlatform() == Platform::kKVM); - - for (int i = 0; i < 1 << 16; ++i) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(PersistentListenerConnectStressTest, Close) { + const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + for (int i = 0; i < nports * 2; i++) { + std::unique_ptr<SocketPair> sockets = + ASSERT_NO_ERRNO_AND_VALUE(NewSocketSleep()); } } @@ -149,7 +234,8 @@ TEST_P(DataTransferStressTest, BigDataTransfer) { // TODO(b/165912341): These are too slow on KVM platform with nested virt. SKIP_IF(GvisorPlatform() == Platform::kKVM); - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + const std::unique_ptr<SocketPair> sockets = + ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); int client_fd = sockets->first_fd(); int server_fd = sockets->second_fd(); |