summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/control/pprof.go6
-rw-r--r--pkg/sentry/devices/tundev/tundev.go4
-rw-r--r--pkg/sentry/fs/proc/sys_net.go95
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go102
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go49
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go2
-rw-r--r--pkg/sentry/inet/inet.go17
-rw-r--r--pkg/sentry/inet/test_stack.go12
-rw-r--r--pkg/sentry/socket/hostinet/stack.go11
-rw-r--r--pkg/sentry/socket/netstack/stack.go14
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/memfd.go1
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go8
-rw-r--r--pkg/tcpip/header/ipv4.go5
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go43
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go12
-rw-r--r--pkg/tcpip/stack/ndp_test.go8
-rw-r--r--pkg/tcpip/stack/nic.go60
-rw-r--r--pkg/tcpip/stack/stack.go21
-rw-r--r--pkg/tcpip/stack/stack_test.go193
-rw-r--r--pkg/tcpip/tests/integration/BUILD21
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go274
-rw-r--r--pkg/tcpip/transport/tcp/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go13
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go15
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go33
-rw-r--r--pkg/tcpip/transport/tcp/rack.go82
-rw-r--r--pkg/tcpip/transport/tcp/rack_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/snd.go40
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go74
-rw-r--r--pkg/test/dockerutil/profile.go37
-rw-r--r--pkg/test/dockerutil/profile_test.go13
33 files changed, 1093 insertions, 210 deletions
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 663e51989..2bf3c45e1 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -49,6 +49,9 @@ type ProfileOpts struct {
// - dump out the stack trace of current go routines.
// sentryctl -pid <pid> pprof-goroutine
type Profile struct {
+ // Kernel is the kernel under profile. It's immutable.
+ Kernel *kernel.Kernel
+
// mu protects the fields below.
mu sync.Mutex
@@ -57,9 +60,6 @@ type Profile struct {
// traceFile is the current execution trace output file.
traceFile *fd.FD
-
- // Kernel is the kernel under profile.
- Kernel *kernel.Kernel
}
// StartCPUProfile is an RPC stub which starts recording the CPU profile in a
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index 852ec3c5c..a40625e19 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -160,8 +160,8 @@ func (fd *tunFD) EventUnregister(e *waiter.Entry) {
fd.device.EventUnregister(e)
}
-// isNetTunSupported returns whether /dev/net/tun device is supported for s.
-func isNetTunSupported(s inet.Stack) bool {
+// IsNetTunSupported returns whether /dev/net/tun device is supported for s.
+func IsNetTunSupported(s inet.Stack) bool {
_, ok := s.(*netstack.Stack)
return ok
}
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 702fdd392..8615b60f0 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -272,6 +272,96 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque
return n, f.tcpSack.stack.SetTCPSACKEnabled(*f.tcpSack.enabled)
}
+// +stateify savable
+type tcpRecovery struct {
+ fsutil.SimpleFileInode
+
+ stack inet.Stack `state:"wait"`
+ recovery inet.TCPLossRecovery
+}
+
+func newTCPRecoveryInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ ts := &tcpRecovery{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ stack: s,
+ }
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(ctx, ts, msrc, sattr)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (*tcpRecovery) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (r *tcpRecovery) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+ return fs.NewFile(ctx, dirent, flags, &tcpRecoveryFile{
+ tcpRecovery: r,
+ stack: r.stack,
+ }), nil
+}
+
+// +stateify savable
+type tcpRecoveryFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ tcpRecovery *tcpRecovery
+
+ stack inet.Stack `state:"wait"`
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *tcpRecoveryFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ return 0, io.EOF
+ }
+
+ recovery, err := f.stack.TCPRecovery()
+ if err != nil {
+ return 0, err
+ }
+ f.tcpRecovery.recovery = recovery
+ s := fmt.Sprintf("%d\n", f.tcpRecovery.recovery)
+ n, err := dst.CopyOut(ctx, []byte(s))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *tcpRecoveryFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+ f.tcpRecovery.recovery = inet.TCPLossRecovery(v)
+ if err := f.tcpRecovery.stack.SetTCPRecovery(f.tcpRecovery.recovery); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
+
func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
// The following files are simple stubs until they are implemented in
// netstack, most of these files are configuration related. We use the
@@ -351,6 +441,11 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine
contents["tcp_wmem"] = newTCPMemInode(ctx, msrc, s, tcpWMem)
}
+ // Add tcp_recovery.
+ if _, err := s.TCPRecovery(); err == nil {
+ contents["tcp_recovery"] = newTCPRecoveryInode(ctx, msrc, s)
+ }
+
d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 420e8efe2..db6bed4f6 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -184,6 +184,7 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
+
// Set offset to file size if the fd was opened with O_APPEND.
if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
// Holding d.metadataMu is sufficient for reading d.size.
@@ -194,70 +195,79 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
return 0, offset, err
}
src = src.TakeFirst64(limit)
- n, err := fd.pwriteLocked(ctx, src, offset, opts)
- return n, offset + n, err
-}
-// Preconditions: fd.dentry().metatdataMu must be locked.
-func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
- d := fd.dentry()
if d.fs.opts.interop != InteropModeShared {
// Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
// file_update_time(). This is d.touchCMtime(), but without locking
// d.metadataMu (recursively).
d.touchCMtimeLocked()
}
- if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
- // Write dirty cached pages that will be touched by the write back to
- // the remote file.
- if err := d.writeback(ctx, offset, src.NumBytes()); err != nil {
- return 0, err
- }
- // Remove touched pages from the cache.
- pgstart := usermem.PageRoundDown(uint64(offset))
- pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes()))
- if !ok {
- return 0, syserror.EINVAL
- }
- mr := memmap.MappableRange{pgstart, pgend}
- var freed []memmap.FileRange
- d.dataMu.Lock()
- cseg := d.cache.LowerBoundSegment(mr.Start)
- for cseg.Ok() && cseg.Start() < mr.End {
- cseg = d.cache.Isolate(cseg, mr)
- freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
- cseg = d.cache.Remove(cseg).NextSegment()
- }
- d.dataMu.Unlock()
- // Invalidate mappings of removed pages.
- d.mapsMu.Lock()
- d.mappings.Invalidate(mr, memmap.InvalidateOpts{})
- d.mapsMu.Unlock()
- // Finally free pages removed from the cache.
- mf := d.fs.mfp.MemoryFile()
- for _, freedFR := range freed {
- mf.DecRef(freedFR)
- }
- }
+
rw := getDentryReadWriter(ctx, d, offset)
+ defer putDentryReadWriter(rw)
+
if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ if err := fd.writeCache(ctx, d, offset, src); err != nil {
+ return 0, offset, err
+ }
+
// Require the write to go to the remote file.
rw.direct = true
}
+
n, err := src.CopyInTo(ctx, rw)
- putDentryReadWriter(rw)
- if n != 0 && fd.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 {
- // Write dirty cached pages touched by the write back to the remote
- // file.
+ if err != nil {
+ return n, offset, err
+ }
+ if n > 0 && fd.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 {
+ // Write dirty cached pages touched by the write back to the remote file.
if err := d.writeback(ctx, offset, src.NumBytes()); err != nil {
- return 0, err
+ return n, offset, err
}
// Request the remote filesystem to sync the remote file.
- if err := d.handle.file.fsync(ctx); err != nil {
- return 0, err
+ if err := d.handle.sync(ctx); err != nil {
+ return n, offset, err
}
}
- return n, err
+ return n, offset + n, nil
+}
+
+func (fd *regularFileFD) writeCache(ctx context.Context, d *dentry, offset int64, src usermem.IOSequence) error {
+ // Write dirty cached pages that will be touched by the write back to
+ // the remote file.
+ if err := d.writeback(ctx, offset, src.NumBytes()); err != nil {
+ return err
+ }
+
+ // Remove touched pages from the cache.
+ pgstart := usermem.PageRoundDown(uint64(offset))
+ pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes()))
+ if !ok {
+ return syserror.EINVAL
+ }
+ mr := memmap.MappableRange{pgstart, pgend}
+ var freed []memmap.FileRange
+
+ d.dataMu.Lock()
+ cseg := d.cache.LowerBoundSegment(mr.Start)
+ for cseg.Ok() && cseg.Start() < mr.End {
+ cseg = d.cache.Isolate(cseg, mr)
+ freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
+ cseg = d.cache.Remove(cseg).NextSegment()
+ }
+ d.dataMu.Unlock()
+
+ // Invalidate mappings of removed pages.
+ d.mapsMu.Lock()
+ d.mappings.Invalidate(mr, memmap.InvalidateOpts{})
+ d.mapsMu.Unlock()
+
+ // Finally free pages removed from the cache.
+ mf := d.fs.mfp.MemoryFile()
+ for _, freedFR := range freed {
+ mf.DecRef(freedFR)
+ }
+ return nil
}
// Write implements vfs.FileDescriptionImpl.Write.
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 6dac2afa4..b71778128 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -55,7 +55,8 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke
if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]*kernfs.Dentry{
"ipv4": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
- "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}),
+ "tcp_recovery": fs.newDentry(root, fs.NextIno(), 0644, &tcpRecoveryData{stack: stack}),
+ "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, most of these files are configuration related. We use the
@@ -207,3 +208,49 @@ func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset
*d.enabled = v != 0
return n, d.stack.SetTCPSACKEnabled(*d.enabled)
}
+
+// tcpRecoveryData implements vfs.WritableDynamicBytesSource for
+// /proc/sys/net/ipv4/tcp_recovery.
+//
+// +stateify savable
+type tcpRecoveryData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack `state:"wait"`
+}
+
+var _ vfs.WritableDynamicBytesSource = (*tcpRecoveryData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.
+func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ recovery, err := d.stack.TCPRecovery()
+ if err != nil {
+ return err
+ }
+
+ buf.WriteString(fmt.Sprintf("%d\n", recovery))
+ return nil
+}
+
+func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ // No need to handle partial writes thus far.
+ return 0, syserror.EINVAL
+ }
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit the amount of memory allocated.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+ if err := d.stack.SetTCPRecovery(inet.TCPLossRecovery(v)); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index fb77f95cc..065812065 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -566,7 +566,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if replaced != nil {
newParentDir.removeChildLocked(replaced)
if replaced.inode.isDir() {
- newParentDir.inode.decLinksLocked(ctx) // from replaced's ".."
+ // Remove links for replaced/. and replaced/..
+ replaced.inode.decLinksLocked(ctx)
+ newParentDir.inode.decLinksLocked(ctx)
}
replaced.inode.decLinksLocked(ctx)
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 68e615e8b..4681a2f52 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -558,6 +558,8 @@ func (i *inode) direntType() uint8 {
return linux.DT_LNK
case *socketFile:
return linux.DT_SOCK
+ case *namedPipe:
+ return linux.DT_FIFO
case *deviceFile:
switch impl.kind {
case vfs.BlockDevice:
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index 2916a0644..c0b4831d1 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -56,6 +56,12 @@ type Stack interface {
// settings.
SetTCPSACKEnabled(enabled bool) error
+ // TCPRecovery returns the TCP loss detection algorithm.
+ TCPRecovery() (TCPLossRecovery, error)
+
+ // SetTCPRecovery attempts to change the TCP loss detection algorithm.
+ SetTCPRecovery(recovery TCPLossRecovery) error
+
// Statistics reports stack statistics.
Statistics(stat interface{}, arg string) error
@@ -189,3 +195,14 @@ type StatSNMPUDP [8]uint64
// StatSNMPUDPLite describes UdpLite line of /proc/net/snmp.
type StatSNMPUDPLite [8]uint64
+
+// TCPLossRecovery indicates TCP loss detection and recovery methods to use.
+type TCPLossRecovery int32
+
+// Loss recovery constants from include/net/tcp.h which are used to set
+// /proc/sys/net/ipv4/tcp_recovery.
+const (
+ TCP_RACK_LOSS_DETECTION TCPLossRecovery = 1 << iota
+ TCP_RACK_STATIC_REO_WND
+ TCP_RACK_NO_DUPTHRESH
+)
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index d8961fc94..9771f01fc 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -25,6 +25,7 @@ type TestStack struct {
TCPRecvBufSize TCPBufferSize
TCPSendBufSize TCPBufferSize
TCPSACKFlag bool
+ Recovery TCPLossRecovery
}
// NewTestStack returns a TestStack with no network interfaces. The value of
@@ -91,6 +92,17 @@ func (s *TestStack) SetTCPSACKEnabled(enabled bool) error {
return nil
}
+// TCPRecovery implements Stack.TCPRecovery.
+func (s *TestStack) TCPRecovery() (TCPLossRecovery, error) {
+ return s.Recovery, nil
+}
+
+// SetTCPRecovery implements Stack.SetTCPRecovery.
+func (s *TestStack) SetTCPRecovery(recovery TCPLossRecovery) error {
+ s.Recovery = recovery
+ return nil
+}
+
// Statistics implements inet.Stack.Statistics.
func (s *TestStack) Statistics(stat interface{}, arg string) error {
return nil
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index a48082631..fda3dcb35 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -53,6 +53,7 @@ type Stack struct {
interfaceAddrs map[int32][]inet.InterfaceAddr
routes []inet.Route
supportsIPv6 bool
+ tcpRecovery inet.TCPLossRecovery
tcpRecvBufSize inet.TCPBufferSize
tcpSendBufSize inet.TCPBufferSize
tcpSACKEnabled bool
@@ -350,6 +351,16 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
return syserror.EACCES
}
+// TCPRecovery implements inet.Stack.TCPRecovery.
+func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
+ return s.tcpRecovery, nil
+}
+
+// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
+func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {
+ return syserror.EACCES
+}
+
// getLine reads one line from proc file, with specified prefix.
// The last argument, withHeader, specifies if it contains line header.
func getLine(f *os.File, prefix string, withHeader bool) string {
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 67737ae87..f0fe18684 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -207,6 +207,20 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError()
}
+// TCPRecovery implements inet.Stack.TCPRecovery.
+func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
+ var recovery tcp.Recovery
+ if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+ return inet.TCPLossRecovery(recovery), nil
+}
+
+// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
+func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError()
+}
+
// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat interface{}, arg string) error {
switch stats := stat.(type) {
diff --git a/pkg/sentry/syscalls/linux/vfs2/memfd.go b/pkg/sentry/syscalls/linux/vfs2/memfd.go
index 519583e4e..c4c0f9e0a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/memfd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/memfd.go
@@ -51,6 +51,7 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
if err != nil {
return 0, nil, err
}
+ defer file.DecRef(t)
fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
CloseOnExec: cloExec,
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 16f59fce9..75bfa2c79 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -347,6 +347,11 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
} else {
spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
}
+ if spliceN == 0 && err == io.EOF {
+ // We reached the end of the file. Eat the error and exit the loop.
+ err = nil
+ break
+ }
n += spliceN
if err == syserror.ErrWouldBlock && !nonBlock {
err = dw.waitForBoth(t)
@@ -367,8 +372,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
}
if readN == 0 && err == io.EOF {
- // We reached the end of the file. Eat the
- // error and exit the loop.
+ // We reached the end of the file. Eat the error and exit the loop.
err = nil
break
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 62ac932bb..d0d1efd0d 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -101,6 +101,11 @@ const (
// IPv4Version is the version of the ipv4 protocol.
IPv4Version = 4
+ // IPv4AllSystems is the all systems IPv4 multicast address as per
+ // IANA's IPv4 Multicast Address Space Registry. See
+ // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml.
+ IPv4AllSystems tcpip.Address = "\xe0\x00\x00\x01"
+
// IPv4Broadcast is the broadcast address of the IPv4 procotol.
IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff"
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index d5f5d38f7..6c4f0ae3e 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -52,27 +52,25 @@ const (
)
type endpoint struct {
- nicID tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
- linkEP stack.LinkEndpoint
- dispatcher stack.TransportDispatcher
- fragmentation *fragmentation.Fragmentation
- protocol *protocol
- stack *stack.Stack
+ nicID tcpip.NICID
+ id stack.NetworkEndpointID
+ prefixLen int
+ linkEP stack.LinkEndpoint
+ dispatcher stack.TransportDispatcher
+ protocol *protocol
+ stack *stack.Stack
}
// NewEndpoint creates a new ipv4 endpoint.
func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
e := &endpoint{
- nicID: nicID,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
- linkEP: linkEP,
- dispatcher: dispatcher,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
- protocol: p,
- stack: st,
+ nicID: nicID,
+ id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ linkEP: linkEP,
+ dispatcher: dispatcher,
+ protocol: p,
+ stack: st,
}
return e, nil
@@ -442,7 +440,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
var ready bool
var err error
- pkt.Data, ready, err = e.fragmentation.Process(
+ pkt.Data, ready, err = e.protocol.fragmentation.Process(
+ // As per RFC 791 section 2.3, the identification value is unique
+ // for a source-destination pair and protocol.
fragmentation.FragmentID{
Source: h.SourceAddress(),
Destination: h.DestinationAddress(),
@@ -484,6 +484,8 @@ type protocol struct {
// uint8 portion of it is meaningful and it must be accessed
// atomically.
defaultTTL uint32
+
+ fragmentation *fragmentation.Fragmentation
}
// Number returns the ipv4 protocol number.
@@ -605,5 +607,10 @@ func NewProtocol() stack.NetworkProtocol {
}
hashIV := r[buckets]
- return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
+ return &protocol{
+ ids: ids,
+ hashIV: hashIV,
+ defaultTTL: DefaultTTL,
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ }
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index a0a5c9c01..4a0b53c45 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -51,7 +51,6 @@ type endpoint struct {
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
- fragmentation *fragmentation.Fragmentation
protocol *protocol
}
@@ -342,7 +341,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
var ready bool
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
- pkt.Data, ready, err = e.fragmentation.Process(
+ pkt.Data, ready, err = e.protocol.fragmentation.Process(
// IPv6 ignores the Protocol field since the ID only needs to be unique
// across source-destination pairs, as per RFC 8200 section 4.5.
fragmentation.FragmentID{
@@ -445,7 +444,8 @@ type protocol struct {
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful and it must be accessed
// atomically.
- defaultTTL uint32
+ defaultTTL uint32
+ fragmentation *fragmentation.Fragmentation
}
// Number returns the ipv6 protocol number.
@@ -478,7 +478,6 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
protocol: p,
}, nil
}
@@ -606,5 +605,8 @@ func calculateMTU(mtu uint32) uint32 {
// NewProtocol returns an IPv6 network protocol.
func NewProtocol() stack.NetworkProtocol {
- return &protocol{defaultTTL: DefaultTTL}
+ return &protocol{
+ defaultTTL: DefaultTTL,
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ }
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 644ba7c33..5d286ccbc 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -1689,13 +1689,7 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix)
AddressWithPrefix: item,
}
- for _, i := range list {
- if i == protocolAddress {
- return true
- }
- }
-
- return false
+ return containsAddr(list, protocolAddress)
}
// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index f21066fce..eaaf756cd 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -217,6 +217,11 @@ func (n *NIC) disableLocked() *tcpip.Error {
}
if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ // The NIC may have already left the multicast group.
+ if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+
// The address may have already been removed.
if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
return err
@@ -255,6 +260,13 @@ func (n *NIC) enable() *tcpip.Error {
if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
+
+ // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
+ // multicast group. Note, the IANA calls the all-hosts multicast group the
+ // all-systems multicast group.
+ if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil {
+ return err
+ }
}
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
@@ -609,6 +621,9 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
// If none exists a temporary one may be created if we are in promiscuous mode
// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
// Similarly, spoofing will only be checked if spoofing is true.
+//
+// If the address is the IPv4 broadcast address for an endpoint's network, that
+// endpoint will be returned.
func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
n.mu.RLock()
@@ -633,6 +648,16 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
}
}
+ // Check if address is a broadcast address for an endpoint's network.
+ //
+ // Only IPv4 has a notion of broadcast addresses.
+ if protocol == header.IPv4ProtocolNumber {
+ if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ n.mu.RUnlock()
+ return ref
+ }
+ }
+
// A usable reference was not found, create a temporary one if requested by
// the caller or if the address is found in the NIC's subnets.
createTempEP := spoofingOrPromiscuous
@@ -670,8 +695,34 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
return ref
}
+// getRefForBroadcastRLocked returns an endpoint where address is the IPv4
+// broadcast address for the endpoint's network.
+//
+// n.mu MUST be read locked.
+func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint {
+ for _, ref := range n.mu.endpoints {
+ // Only IPv4 has a notion of broadcast addresses.
+ if ref.protocol != header.IPv4ProtocolNumber {
+ continue
+ }
+
+ addr := ref.addrWithPrefix()
+ subnet := addr.Subnet()
+ if subnet.IsBroadcast(address) && ref.tryIncRef() {
+ return ref
+ }
+ }
+
+ return nil
+}
+
/// getRefOrCreateTempLocked returns an existing endpoint for address or creates
/// and returns a temporary endpoint.
+//
+// If the address is the IPv4 broadcast address for an endpoint's network, that
+// endpoint will be returned.
+//
+// n.mu must be write locked.
func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// No need to check the type as we are ok with expired endpoints at this
@@ -685,6 +736,15 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add
n.removeEndpointLocked(ref)
}
+ // Check if address is a broadcast address for an endpoint's network.
+ //
+ // Only IPv4 has a notion of broadcast addresses.
+ if protocol == header.IPv4ProtocolNumber {
+ if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ return ref
+ }
+ }
+
// Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 3f07e4159..5b19c5d59 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -73,6 +73,16 @@ type TCPCubicState struct {
WEst float64
}
+// TCPRACKState is used to hold a copy of the internal RACK state when the
+// TCPProbeFunc is invoked.
+type TCPRACKState struct {
+ XmitTime time.Time
+ EndSequence seqnum.Value
+ FACK seqnum.Value
+ RTT time.Duration
+ Reord bool
+}
+
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
type TCPEndpointID struct {
// LocalPort is the local port associated with the endpoint.
@@ -212,6 +222,9 @@ type TCPSenderState struct {
// Cubic holds the state related to CUBIC congestion control.
Cubic TCPCubicState
+
+ // RACKState holds the state related to the RACK loss detection algorithm.
+ RACKState TCPRACKState
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
@@ -1972,8 +1985,8 @@ func generateRandInt64() int64 {
// FindNetworkEndpoint returns the network endpoint for the given address.
func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) {
- s.mu.Lock()
- defer s.mu.Unlock()
+ s.mu.RLock()
+ defer s.mu.RUnlock()
for _, nic := range s.nics {
id := NetworkEndpointID{address}
@@ -1992,8 +2005,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
// FindNICNameFromID returns the name of the nic for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
- s.mu.Lock()
- defer s.mu.Unlock()
+ s.mu.RLock()
+ defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index f22062889..0b6deda02 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -277,6 +277,17 @@ func (l *linkEPWithMockedAttach) isAttached() bool {
return l.attached
}
+// containsAddr reports whether list contains item.
+func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool {
+ for _, i := range list {
+ if i == item {
+ return true
+ }
+ }
+
+ return false
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
@@ -1704,7 +1715,7 @@ func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, sub
// Trying the next address should always fail since it is outside the range.
if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0)
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = 0", fakeNetNumber, tcpip.Address(addrBytes), gotNicID)
}
}
@@ -3089,6 +3100,13 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
const nicID = 1
+ broadcastAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 32,
+ },
+ }
e := loopback.New()
s := stack.New(stack.Options{
@@ -3099,49 +3117,41 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
}
- allStackAddrs := s.AllAddresses()
- allNICAddrs, ok := allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 0 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr)
+ }
}
// Enabling the NIC should add the IPv4 broadcast address.
if err := s.EnableNIC(nicID); err != nil {
t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
}
- allStackAddrs = s.AllAddresses()
- allNICAddrs, ok = allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 1 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 1", l)
- }
- want := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: header.IPv4Broadcast,
- PrefixLen: 32,
- },
- }
- if allNICAddrs[0] != want {
- t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want)
+
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if !containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr)
+ }
}
// Disabling the NIC should remove the IPv4 broadcast address.
if err := s.DisableNIC(nicID); err != nil {
t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
}
- allStackAddrs = s.AllAddresses()
- allNICAddrs, ok = allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 0 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr)
+ }
}
}
@@ -3189,50 +3199,93 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
}
}
-func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
+func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) {
const nicID = 1
- e := loopback.New()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- })
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.Address
+ }{
+ {
+ name: "IPv6 All-Nodes",
+ proto: header.IPv6ProtocolNumber,
+ addr: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "IPv4 All-Systems",
+ proto: header.IPv4ProtocolNumber,
+ addr: header.IPv4AllSystems,
+ },
}
- // Should not be in the IPv6 all-nodes multicast group yet because the NIC has
- // not been enabled yet.
- isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
- // The all-nodes multicast group should be joined when the NIC is enabled.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if !isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress)
- }
+ // Should not be in the multicast group yet because the NIC has not been
+ // enabled yet.
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
- // The all-nodes multicast group should be left when the NIC is disabled.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ // The multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr)
+ }
+
+ // The multicast group should be left when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
+
+ // The multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr)
+ }
+
+ // Leaving the group before disabling the NIC should not cause an error.
+ if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil {
+ t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err)
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
+ })
}
}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
new file mode 100644
index 000000000..7fff30462
--- /dev/null
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+go_test(
+ name = "integration_test",
+ size = "small",
+ srcs = ["multicast_broadcast_test.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
new file mode 100644
index 000000000..d9b2d147a
--- /dev/null
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -0,0 +1,274 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const defaultMTU = 1280
+
+// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some
+// multicast or broadcast address.
+func TestIncomingMulticastAndBroadcast(t *testing.T) {
+ const (
+ nicID = 1
+ remotePort = 5555
+ localPort = 80
+ ttl = 255
+ )
+
+ data := []byte{1, 2, 3, 4}
+
+ // Local IPv4 subnet: 192.168.1.58/24
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+
+ // Local IPv6 subnet: 200a::1/64
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+
+ // Remote addrs.
+ remoteIPv4Addr := tcpip.Address("\x64\x0a\x7b\x18")
+ remoteIPv6Addr := tcpip.Address("\x20\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+
+ rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ totalLen := header.IPv4MinimumSize + payloadLen
+ hdr := buffer.NewPrependable(totalLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(udp.ProtocolNumber),
+ TTL: ttl,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLen),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ tests := []struct {
+ name string
+ bindAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectRx bool
+ }{
+ {
+ name: "IPv4 unicast binding to unicast",
+ bindAddr: ipv4Addr.Address,
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 unicast binding to broadcast",
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4Addr.Address,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 unicast binding to wildcard",
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
+ },
+
+ {
+ name: "IPv4 directed broadcast binding to subnet broadcast",
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 directed broadcast binding to broadcast",
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 directed broadcast binding to wildcard",
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
+ },
+
+ {
+ name: "IPv4 broadcast binding to broadcast",
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 broadcast binding to subnet broadcast",
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 broadcast binding to wildcard",
+ dstAddr: header.IPv4Broadcast, // limited broadcast, matching the case name
+ expectRx: true,
+ },
+
+ {
+ name: "IPv4 all-systems multicast binding to all-systems multicast",
+ bindAddr: header.IPv4AllSystems,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 all-systems multicast binding to wildcard",
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 all-systems multicast binding to unicast",
+ bindAddr: ipv4Addr.Address,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: false,
+ },
+
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 unicast binding to wildcard",
+ dstAddr: ipv6Addr.Address,
+ expectRx: true,
+ },
+ {
+ name: "IPv6 broadcast-like address binding to wildcard",
+ dstAddr: ipv6SubnetBcast,
+ expectRx: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
+ }
+ ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
+ }
+
+ var netproto tcpip.NetworkProtocolNumber
+ var rxUDP func(*channel.Endpoint, tcpip.Address)
+ switch l := len(test.dstAddr); l {
+ case header.IPv4AddressSize:
+ netproto = header.IPv4ProtocolNumber
+ rxUDP = rxIPv4UDP
+ case header.IPv6AddressSize:
+ netproto = header.IPv6ProtocolNumber
+ rxUDP = rxIPv6UDP
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%+v): %s", bindAddr, err)
+ }
+
+ rxUDP(e, test.dstAddr)
+ if gotPayload, _, err := ep.Read(nil); test.expectRx {
+ if err != nil {
+ t.Fatalf("Read(nil): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index e860ee484..234fb95ce 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -40,6 +40,8 @@ go_library(
"endpoint_state.go",
"forwarder.go",
"protocol.go",
+ "rack.go",
+ "rack_state.go",
"rcv.go",
"rcv_state.go",
"reno.go",
@@ -83,6 +85,7 @@ go_test(
"dual_stack_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
+ "tcp_rack_test.go",
"tcp_sack_test.go",
"tcp_test.go",
"tcp_timestamp_test.go",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6e00e5526..913ea6535 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -521,7 +521,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(timeStampOffset()),
+ TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
TSEcr: opts.TSVal,
MSS: mssForRoute(&s.route),
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 6e5e55b6f..8dd759ba2 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1166,13 +1166,18 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
return nil
}
-// handleSegment handles a given segment and notifies the worker goroutine if
-// if the connection should be terminated.
-func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
- // Invoke the tcp probe if installed.
+func (e *endpoint) probeSegment() {
if e.probe != nil {
e.probe(e.completeState())
}
+}
+
+// handleSegment handles a given segment and notifies the worker goroutine if
+// the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed. The tcp probe function will update
+ // the TCPEndpointState after the segment is processed.
+ defer e.probeSegment()
if s.flagIsSet(header.TCPFlagRst) {
if ok, err := e.handleReset(s); !ok {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 682687ebe..39ea38fe6 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2692,15 +2692,14 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
// timestamp returns the timestamp value to be used in the TSVal field of the
// timestamp option for outgoing TCP segments for a given endpoint.
func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(e.tsOffset)
+ return tcpTimeStamp(time.Now(), e.tsOffset)
}
// tcpTimeStamp returns a timestamp offset by the provided offset. This is
// not inlined above as it's used when SYN cookies are in use and endpoint
// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(offset uint32) uint32 {
- now := time.Now()
- return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset
+func tcpTimeStamp(curTime time.Time, offset uint32) uint32 {
+ return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset
}
// timeStampOffset returns a randomized timestamp offset to be used when sending
@@ -2843,6 +2842,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
WEst: cubic.wEst,
}
}
+
+ rc := e.snd.rc
+ s.Sender.RACKState = stack.TCPRACKState{
+ XmitTime: rc.xmitTime,
+ EndSequence: rc.endSequence,
+ FACK: rc.fack,
+ RTT: rc.rtt,
+ }
return s
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index b34e47bbd..d9abb8d94 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -80,6 +80,25 @@ const (
// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
+// Recovery is used by stack.(*Stack).TransportProtocolOption to
+// set loss detection algorithm in TCP.
+type Recovery int32
+
+const (
+ // RACKLossDetection indicates RACK is used for loss detection and
+ // recovery.
+ RACKLossDetection Recovery = 1 << iota
+
+ // RACKStaticReoWnd indicates the reordering window should not be
+ // adjusted when DSACK is received.
+ RACKStaticReoWnd
+
+ // RACKNoDupTh indicates RACK should not consider the classic three
+ // duplicate acknowledgements rule to mark the segments as lost. This
+ // is used when reordering is not detected.
+ RACKNoDupTh
+)
+
// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to
// enable/disable Nagle's algorithm in TCP.
type DelayEnabled bool
@@ -161,6 +180,7 @@ func (s *synRcvdCounter) Threshold() uint64 {
type protocol struct {
mu sync.RWMutex
sackEnabled bool
+ recovery Recovery
delayEnabled bool
sendBufferSize SendBufferSizeOption
recvBufferSize ReceiveBufferSizeOption
@@ -280,6 +300,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case Recovery:
+ p.mu.Lock()
+ p.recovery = Recovery(v)
+ p.mu.Unlock()
+ return nil
+
case DelayEnabled:
p.mu.Lock()
p.delayEnabled = bool(v)
@@ -394,6 +420,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.RUnlock()
return nil
+ case *Recovery:
+ p.mu.RLock()
+ *v = Recovery(p.recovery)
+ p.mu.RUnlock()
+ return nil
+
case *DelayEnabled:
p.mu.RLock()
*v = DelayEnabled(p.delayEnabled)
@@ -535,6 +567,7 @@ func NewProtocol() stack.TransportProtocol {
minRTO: MinRTO,
maxRTO: MaxRTO,
maxRetries: MaxRetries,
+ recovery: RACKLossDetection,
}
p.dispatcher.init(runtime.GOMAXPROCS(0))
return &p
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
new file mode 100644
index 000000000..d969ca23a
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -0,0 +1,82 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// RACK is a loss detection algorithm used in TCP to detect packet loss and
+// reordering using transmission timestamp of the packets instead of packet or
+// sequence counts. To use RACK, SACK should be enabled on the connection.
+
+// rackControl stores the rack related fields.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-6.1
+//
+// +stateify savable
+type rackControl struct {
+ // xmitTime is the transmit time of the latest-sent segment ACKed so far (see Update).
+ xmitTime time.Time `state:".(unixTime)"`
+
+ // endSequence is the ending TCP sequence number of the segment xmitTime refers to.
+ endSequence seqnum.Value
+
+ // fack is the highest selectively or cumulatively acknowledged
+ // sequence.
+ fack seqnum.Value
+
+ // rtt is the RTT of the most recently delivered packet on the
+ // connection (either cumulatively acknowledged or selectively
+ // acknowledged) that was not marked invalid as a possible spurious
+ // retransmission.
+ rtt time.Duration
+}
+
+// Update will update the RACK related fields when an ACK has been received.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) {
+ rtt := time.Now().Sub(seg.xmitTime)
+
+ // If the ACK is for a retransmitted packet, do not update if it is a
+ // spurious inference which is determined by below checks:
+ // 1. When Timestamping option is available, if the TSEcr is less than the
+ // transmit time of the most recent retransmitted packet.
+ // 2. When RTT calculated for the packet is less than the smoothed RTT
+ // for the connection.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+ // step 2
+ if seg.xmitCount > 1 {
+ if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
+ if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) {
+ return
+ }
+ }
+ if rtt < srtt {
+ return
+ }
+ }
+
+ rc.rtt = rtt
+ // Update rc.xmitTime and rc.endSequence to the transmit time and
+ // ending sequence number of the packet which has been acknowledged
+ // most recently.
+ endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) {
+ rc.xmitTime = seg.xmitTime
+ rc.endSequence = endSeq
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go
new file mode 100644
index 000000000..c9dc7e773
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rack_state.go
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveXmitTime is invoked by stateify. It stores whole seconds plus the sub-second nanosecond remainder, the pair that time.Unix(sec, nsec) in loadXmitTime adds back together.
+func (rc *rackControl) saveXmitTime() unixTime {
+ return unixTime{rc.xmitTime.Unix(), int64(rc.xmitTime.Nanosecond())}
+}
+
+// loadXmitTime is invoked by stateify.
+func (rc *rackControl) loadXmitTime(unix unixTime) {
+ rc.xmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 5862c32f2..c55589c45 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -191,6 +191,10 @@ type sender struct {
// cc is the congestion control algorithm in use for this sender.
cc congestionControl
+
+ // rc has the fields needed for implementing RACK loss detection
+ // algorithm.
+ rc rackControl
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -1272,21 +1276,21 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
-func (s *sender) handleRcvdSegment(seg *segment) {
+func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Check if we can extract an RTT measurement from this ack.
- if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
s.updateRTO(time.Now().Sub(s.rttMeasureTime))
s.rttMeasureSeqNum = s.sndNxt
}
// Update Timestamp if required. See RFC7323, section-4.3.
- if s.ep.sendTSOk && seg.parsedOptions.TS {
- s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+ if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber)
}
// Insert SACKBlock information into our scoreboard.
if s.ep.sackPermitted {
- for _, sb := range seg.parsedOptions.SACKBlocks {
+ for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
// true:
// * SACK block acks data after the ack number in the
@@ -1299,27 +1303,27 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// NOTE: This check specifically excludes DSACK blocks
// which have start/end before sndUna and are used to
// indicate spurious retransmissions.
- if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
s.ep.scoreboard.Insert(sb)
- seg.hasNewSACKInfo = true
+ rcvdSeg.hasNewSACKInfo = true
}
}
s.SetPipe()
}
// Count the duplicates and do the fast retransmit if needed.
- rtx := s.checkDuplicateAck(seg)
+ rtx := s.checkDuplicateAck(rcvdSeg)
// Stash away the current window size.
- s.sndWnd = seg.window
+ s.sndWnd = rcvdSeg.window
- ack := seg.ackNumber
+ ack := rcvdSeg.ackNumber
// Disable zero window probing if remote advertizes a non-zero receive
// window. This can be with an ACK to the zero window probe (where the
// acknumber refers to the already acknowledged byte) OR to any previously
// unacknowledged segment.
- if s.zeroWindowProbing && seg.window > 0 &&
+ if s.zeroWindowProbing && rcvdSeg.window > 0 &&
(ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) {
s.disableZeroWindowProbing()
}
@@ -1344,10 +1348,10 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// averaged RTT measurement only if the segment acknowledges
// some new data, i.e., only if it advances the left edge of
// the send window.
- if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
// TSVal/Ecr values sent by Netstack are at a millisecond
// granularity.
- elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
s.updateRTO(elapsed)
}
@@ -1361,6 +1365,9 @@ func (s *sender) handleRcvdSegment(seg *segment) {
ackLeft := acked
originalOutstanding := s.outstanding
+ s.rtt.Lock()
+ srtt := s.rtt.srtt
+ s.rtt.Unlock()
for ackLeft > 0 {
// We use logicalLen here because we can have FIN
// segments (which are always at the end of list) that
@@ -1380,6 +1387,11 @@ func (s *sender) handleRcvdSegment(seg *segment) {
s.writeNext = seg.Next()
}
+ // Update the RACK fields if SACK is enabled.
+ if s.ep.sackPermitted {
+ s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset)
+ }
+
s.writeList.Remove(seg)
// if SACK is enabled then Only reduce outstanding if
@@ -1435,7 +1447,7 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// that the window opened up, or the congestion window was inflated due
// to a duplicate ack during fast recovery. This will also re-enable
// the retransmit timer if needed.
- if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+ if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo {
s.sendData()
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
new file mode 100644
index 000000000..e03f101e8
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+)
+
+// TestRACKUpdate checks, from a TCP probe registered on the stack, that the sender's RACK state (XmitTime, EndSequence and RTT)
+// has been updated when an ACK is received on a connection set up with SACK permitted on the stack.
+func TestRACKUpdate(t *testing.T) {
+ const maxPayload = 10 // length of the data view written below
+ const tsOptionSize = 12 // handed to ReceiveAndCheckPacketWithOptions; presumably the padded TCP timestamp option size — confirm
+ const maxTCPOptionSize = 40 // headroom for TCP options in the size passed to context.New
+
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) // NOTE(review): second argument is presumably the MTU — confirm
+ defer c.Cleanup()
+
+ var xmitTime time.Time // assigned just before the Write below and read by the probe
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.RACKState is what we expect: XmitTime not before xmitTime, EndSequence checked against SndNxt, RTT non-zero.
+ if state.Sender.RACKState.XmitTime.Before(xmitTime) {
+ t.Fatalf("RACK transmit time failed to update when an ACK is received") // NOTE(review): if the probe runs off the test goroutine, Fatalf is not valid here — confirm
+ }
+
+ gotSeq := state.Sender.RACKState.EndSequence
+ wantSeq := state.Sender.SndNxt
+ if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { // NOTE(review): fails on gotSeq > wantSeq or gotSeq < wantSeq, which reads as gotSeq != wantSeq — confirm
+ t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq)
+ }
+
+ if state.Sender.RACKState.RTT == 0 {
+ t.Fatalf("RACK RTT failed to update when an ACK is received")
+ }
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ data := buffer.NewView(maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data, recording the time just beforehand so the probe can compare RACK's XmitTime against it.
+ xmitTime = time.Now()
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ c.SendAck(790, bytesRead) // NOTE(review): 790 is presumably the test peer's sequence number; bytesRead == maxPayload here — confirm
+ time.Sleep(200 * time.Millisecond) // NOTE(review): presumably leaves time for the probe to run after the ACK; a signal from the probe would avoid the fixed delay
+}
diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go
index 1fab33083..f0396ef24 100644
--- a/pkg/test/dockerutil/profile.go
+++ b/pkg/test/dockerutil/profile.go
@@ -49,17 +49,16 @@ type Profile interface {
// should have --profile set as an option in /etc/docker/daemon.json in
// order for profiling to work with Pprof.
type Pprof struct {
- BasePath string // path to put profiles
- BlockProfile bool
- CPUProfile bool
- GoRoutineProfile bool
- HeapProfile bool
- MutexProfile bool
- Duration time.Duration // duration to run profiler e.g. '10s' or '1m'.
- shouldRun bool
- cmd *exec.Cmd
- stdout io.ReadCloser
- stderr io.ReadCloser
+ BasePath string // path to put profiles
+ BlockProfile bool
+ CPUProfile bool
+ HeapProfile bool
+ MutexProfile bool
+ Duration time.Duration // duration to run profiler e.g. '10s' or '1m'.
+ shouldRun bool
+ cmd *exec.Cmd
+ stdout io.ReadCloser
+ stderr io.ReadCloser
}
// MakePprofFromFlags makes a Pprof profile from flags.
@@ -68,13 +67,12 @@ func MakePprofFromFlags(c *Container) *Pprof {
return nil
}
return &Pprof{
- BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name),
- BlockProfile: *pprofBlock,
- CPUProfile: *pprofCPU,
- GoRoutineProfile: *pprofGo,
- HeapProfile: *pprofHeap,
- MutexProfile: *pprofMutex,
- Duration: *duration,
+ BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name),
+ BlockProfile: *pprofBlock,
+ CPUProfile: *pprofCPU,
+ HeapProfile: *pprofHeap,
+ MutexProfile: *pprofMutex,
+ Duration: *duration,
}
}
@@ -138,9 +136,6 @@ func (p *Pprof) makeProfileArgs(c *Container) []string {
if p.CPUProfile {
ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof")))
}
- if p.GoRoutineProfile {
- ret = append(ret, fmt.Sprintf("--profile-goroutine=%s", filepath.Join(p.BasePath, "go.pprof")))
- }
if p.HeapProfile {
ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof")))
}
diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go
index b7b4d7618..8c4ffe483 100644
--- a/pkg/test/dockerutil/profile_test.go
+++ b/pkg/test/dockerutil/profile_test.go
@@ -51,13 +51,12 @@ func TestPprof(t *testing.T) {
{
name: "All",
pprof: Pprof{
- BasePath: basePath,
- BlockProfile: true,
- CPUProfile: true,
- GoRoutineProfile: true,
- HeapProfile: true,
- MutexProfile: true,
- Duration: 2 * time.Second,
+ BasePath: basePath,
+ BlockProfile: true,
+ CPUProfile: true,
+ HeapProfile: true,
+ MutexProfile: true,
+ Duration: 2 * time.Second,
},
expectedFiles: []string{block, cpu, goprofle, heap, mutex},
},