summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/ptrace_amd64.go11
-rw-r--r--pkg/abi/linux/ptrace_arm64.go11
-rw-r--r--pkg/log/BUILD1
-rw-r--r--pkg/log/log.go3
-rw-r--r--pkg/sentry/fs/copy_up.go11
-rw-r--r--pkg/sentry/fs/gofer/inode_state.go1
-rw-r--r--pkg/sentry/fs/proc/sys_net.go121
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go5
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go5
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go128
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go57
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go110
-rw-r--r--pkg/sentry/fsimpl/host/host.go5
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go5
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go15
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go5
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go5
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go78
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go5
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go5
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go5
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go5
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go17
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go68
-rw-r--r--pkg/sentry/inet/inet.go8
-rw-r--r--pkg/sentry/inet/test_stack.go12
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go120
-rw-r--r--pkg/sentry/socket/hostinet/stack.go11
-rw-r--r--pkg/sentry/socket/netstack/netstack.go161
-rw-r--r--pkg/sentry/socket/netstack/stack.go14
-rw-r--r--pkg/sentry/vfs/anonfs.go5
-rw-r--r--pkg/sentry/vfs/dentry.go8
-rw-r--r--pkg/sentry/vfs/filesystem.go9
-rw-r--r--pkg/sentry/vfs/mount.go25
-rw-r--r--pkg/sync/mutex_unsafe.go1
-rw-r--r--pkg/sync/rwmutex_unsafe.go6
-rw-r--r--pkg/syserr/netstack.go3
-rw-r--r--pkg/tcpip/buffer/view.go10
-rw-r--r--pkg/tcpip/buffer/view_test.go46
-rw-r--r--pkg/tcpip/checker/checker.go18
-rw-r--r--pkg/tcpip/errors.go13
-rw-r--r--pkg/tcpip/header/checksum.go53
-rw-r--r--pkg/tcpip/header/checksum_test.go113
-rw-r--r--pkg/tcpip/header/icmpv4.go5
-rw-r--r--pkg/tcpip/header/icmpv6.go17
-rw-r--r--pkg/tcpip/header/parse/parse.go30
-rw-r--r--pkg/tcpip/header/tcp.go28
-rw-r--r--pkg/tcpip/header/tcp_test.go20
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go4
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go6
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go4
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go6
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go14
-rw-r--r--pkg/tcpip/link/tun/device.go2
-rw-r--r--pkg/tcpip/network/arp/arp.go2
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation.go6
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_test.go16
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler.go4
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler_test.go2
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection.go17
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go22
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go2
-rw-r--r--pkg/tcpip/network/ip_test.go21
-rw-r--r--pkg/tcpip/network/ipv4/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go27
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go4
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go23
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go121
-rw-r--r--pkg/tcpip/network/ipv4/stats.go4
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go140
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go52
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go66
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go14
-rw-r--r--pkg/tcpip/network/ipv6/mld.go6
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go10
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go68
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go50
-rw-r--r--pkg/tcpip/network/multicast_group_test.go14
-rw-r--r--pkg/tcpip/ports/BUILD5
-rw-r--r--pkg/tcpip/ports/flags.go150
-rw-r--r--pkg/tcpip/ports/ports.go619
-rw-r--r--pkg/tcpip/ports/ports_test.go59
-rw-r--r--pkg/tcpip/stack/conntrack.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go2
-rw-r--r--pkg/tcpip/stack/ndp_test.go123
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go1064
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go19
-rw-r--r--pkg/tcpip/stack/nic.go6
-rw-r--r--pkg/tcpip/stack/packet_buffer.go248
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go224
-rw-r--r--pkg/tcpip/stack/registration.go44
-rw-r--r--pkg/tcpip/stack/stack.go12
-rw-r--r--pkg/tcpip/stack/stack_test.go10
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go6
-rw-r--r--pkg/tcpip/tcpip.go171
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go35
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go78
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go2
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go6
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go46
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go6
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go20
-rw-r--r--pkg/tcpip/transport/tcp/connect.go12
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go76
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go13
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment.go8
-rw-r--r--pkg/tcpip/transport/tcp/snd.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go70
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go4
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go12
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go54
-rw-r--r--pkg/tcpip/transport/udp/protocol.go2
117 files changed, 3464 insertions, 1906 deletions
diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go
index ed3881e27..50e22fe7e 100644
--- a/pkg/abi/linux/ptrace_amd64.go
+++ b/pkg/abi/linux/ptrace_amd64.go
@@ -50,3 +50,14 @@ type PtraceRegs struct {
Fs uint64
Gs uint64
}
+
+// InstructionPointer returns the address of the next instruction to
+// be executed.
+func (p *PtraceRegs) InstructionPointer() uint64 {
+ return p.Rip
+}
+
+// StackPointer returns the address of the Stack pointer.
+func (p *PtraceRegs) StackPointer() uint64 {
+ return p.Rsp
+}
diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go
index 6147738b3..da36811d2 100644
--- a/pkg/abi/linux/ptrace_arm64.go
+++ b/pkg/abi/linux/ptrace_arm64.go
@@ -27,3 +27,14 @@ type PtraceRegs struct {
Pc uint64
Pstate uint64
}
+
+// InstructionPointer returns the address of the next instruction to be
+// executed.
+func (p *PtraceRegs) InstructionPointer() uint64 {
+ return p.Pc
+}
+
+// StackPointer returns the address of the Stack pointer.
+func (p *PtraceRegs) StackPointer() uint64 {
+ return p.Sp
+}
diff --git a/pkg/log/BUILD b/pkg/log/BUILD
index 23ef7ea8d..3ed6aba5c 100644
--- a/pkg/log/BUILD
+++ b/pkg/log/BUILD
@@ -18,7 +18,6 @@ go_library(
deps = [
"//pkg/linewriter",
"//pkg/sync",
- "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/log/log.go b/pkg/log/log.go
index d39af3bf4..073cf6238 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -40,7 +40,6 @@ import (
"sync/atomic"
"time"
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/linewriter"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -105,7 +104,7 @@ func (l *Writer) Write(data []byte) (int, error) {
n += w
// Is it a non-blocking socket?
- if pathErr, ok := err.(*os.PathError); ok && pathErr.Err == unix.EAGAIN {
+ if pathErr, ok := err.(*os.PathError); ok && pathErr.Timeout() {
runtime.Gosched()
continue
}
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index 8e0aa9019..58deb25fc 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -303,17 +303,18 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
// Take a reference on the upper Inode (transferred to
// next.Inode.overlay.upper) and make new translations use it.
- next.Inode.overlay.dataMu.Lock()
+ overlay := next.Inode.overlay
+ overlay.dataMu.Lock()
childUpperInode.IncRef()
- next.Inode.overlay.upper = childUpperInode
- next.Inode.overlay.dataMu.Unlock()
+ overlay.upper = childUpperInode
+ overlay.dataMu.Unlock()
// Invalidate existing translations through the lower Inode.
- next.Inode.overlay.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ overlay.mappings.InvalidateAll(memmap.InvalidateOpts{})
// Remove existing memory mappings from the lower Inode.
if lowerMappable != nil {
- for seg := next.Inode.overlay.mappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ for seg := overlay.mappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
for m := range seg.Value() {
lowerMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable)
}
diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go
index 141e3c27f..e2af1d2ae 100644
--- a/pkg/sentry/fs/gofer/inode_state.go
+++ b/pkg/sentry/fs/gofer/inode_state.go
@@ -109,6 +109,7 @@ func (i *inodeFileState) loadLoading(_ struct{}) {
}
// afterLoad is invoked by stateify.
+// +checklocks:i.loading
func (i *inodeFileState) afterLoad() {
load := func() (err error) {
// See comment on i.loading().
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 52061175f..bbe282c03 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -17,6 +17,7 @@ package proc
import (
"fmt"
"io"
+ "math"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -26,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -498,6 +500,120 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO
return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled)
}
+// portRangeInode implements fs.InodeOperations. It provides and allows
+// modification of the range of ephemeral ports that IPv4 and IPv6 sockets
+// choose from.
+//
+// +stateify savable
+type portRangeInode struct {
+ fsutil.SimpleFileInode
+
+ stack inet.Stack `state:"wait"`
+
+ // start and end store the port range. We must save/restore this here,
+ // since a netstack instance is created on restore.
+ start *uint16
+ end *uint16
+}
+
+func newPortRangeInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ ipf := &portRangeInode{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ stack: s,
+ }
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(ctx, ipf, msrc, sattr)
+}
+
+// Truncate implements fs.InodeOperations.Truncate. Truncate is called when
+// O_TRUNC is specified for any kind of existing Dirent but is not called via
+// (f)truncate for proc files.
+func (*portRangeInode) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// +stateify savable
+type portRangeFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ inode *portRangeInode
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (in *portRangeInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+ return fs.NewFile(ctx, dirent, flags, &portRangeFile{
+ inode: in,
+ }), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (pf *portRangeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ return 0, io.EOF
+ }
+
+ if pf.inode.start == nil {
+ start, end := pf.inode.stack.PortRange()
+ pf.inode.start = &start
+ pf.inode.end = &end
+ }
+
+ contents := fmt.Sprintf("%d %d\n", *pf.inode.start, *pf.inode.end)
+ n, err := dst.CopyOut(ctx, []byte(contents))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+//
+// Offset is ignored, multiple writes are not supported.
+func (pf *portRangeFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Only consider size of one memory page for input for performance
+ // reasons.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ ports := make([]int32, 2)
+ n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ // Port numbers must be uint16s.
+ if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 {
+ return 0, syserror.EINVAL
+ }
+
+ if err := pf.inode.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil {
+ return 0, err
+ }
+ if pf.inode.start == nil {
+ pf.inode.start = new(uint16)
+ pf.inode.end = new(uint16)
+ }
+ *pf.inode.start = uint16(ports[0])
+ *pf.inode.end = uint16(ports[1])
+ return n, nil
+}
+
func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
contents := map[string]*fs.Inode{
// Add tcp_sack.
@@ -506,12 +622,15 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine
// Add ip_forward.
"ip_forward": newIPForwardingInode(ctx, msrc, s),
+ // Allow for configurable ephemeral port ranges. Note that this
+ // controls ports for both IPv4 and IPv6 sockets.
+ "ip_local_port_range": newPortRangeInode(ctx, msrc, s),
+
// The following files are simple stubs until they are
// implemented in netstack, most of these files are
// configuration related. We use the value closest to the
// actual netstack behavior or any empty file, all of these
// files will have mode 0444 (read-only for all users).
- "ip_local_port_range": newStaticProcInode(ctx, msrc, []byte("16000 65535")),
"ip_local_reserved_ports": newStaticProcInode(ctx, msrc, []byte("")),
"ipfrag_time": newStaticProcInode(ctx, msrc, []byte("30")),
"ip_nonlocal_bind": newStaticProcInode(ctx, msrc, []byte("0")),
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index d8c237753..e75954105 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -137,6 +137,11 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.Release(ctx)
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
// rootInode is the root directory inode for the devpts mounts.
//
// +stateify savable
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 917f1873d..d4fc484a2 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -548,3 +548,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.mu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 204d8d143..fef857afb 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -47,19 +47,14 @@ type FilesystemType struct{}
// +stateify savable
type filesystemOptions struct {
- // userID specifies the numeric uid of the mount owner.
- // This option should not be specified by the filesystem owner.
- // It is set by libfuse (or, if libfuse is not used, must be set
- // by the filesystem itself). For more information, see man page
- // for fuse(8)
- userID uint32
-
- // groupID specifies the numeric gid of the mount owner.
- // This option should not be specified by the filesystem owner.
- // It is set by libfuse (or, if libfuse is not used, must be set
- // by the filesystem itself). For more information, see man page
- // for fuse(8)
- groupID uint32
+ // mopts contains the raw, unparsed mount options passed to this filesystem.
+ mopts string
+
+ // uid of the mount owner.
+ uid auth.KUID
+
+ // gid of the mount owner.
+ gid auth.KGID
// rootMode specifies the the file mode of the filesystem's root.
rootMode linux.FileMode
@@ -73,6 +68,19 @@ type filesystemOptions struct {
// specified as "max_read" in fs parameters.
// If not specified by user, use math.MaxUint32 as default value.
maxRead uint32
+
+ // defaultPermissions is the default_permissions mount option. It instructs
+ // the kernel to perform a standard unix permission checks based on
+ // ownership and mode bits, instead of deferring the check to the server.
+ //
+ // Immutable after mount.
+ defaultPermissions bool
+
+ // allowOther is the allow_other mount option. It allows processes that
+ // don't own the FUSE mount to call into it.
+ //
+ // Immutable after mount.
+ allowOther bool
}
// filesystem implements vfs.FilesystemImpl.
@@ -108,18 +116,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
- var fsopts filesystemOptions
+ fsopts := filesystemOptions{mopts: opts.Data}
mopts := vfs.GenericParseMountOptions(opts.Data)
deviceDescriptorStr, ok := mopts["fd"]
if !ok {
- log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name())
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option fd missing")
return nil, nil, syserror.EINVAL
}
delete(mopts, "fd")
deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
if err != nil {
- log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err)
+ ctx.Debugf("fusefs.FilesystemType.GetFilesystem: invalid fd: %q (%v)", deviceDescriptorStr, err)
return nil, nil, syserror.EINVAL
}
@@ -141,38 +149,54 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Parse and set all the other supported FUSE mount options.
// TODO(gVisor.dev/issue/3229): Expand the supported mount options.
- if userIDStr, ok := mopts["user_id"]; ok {
+ if uidStr, ok := mopts["user_id"]; ok {
delete(mopts, "user_id")
- userID, err := strconv.ParseUint(userIDStr, 10, 32)
+ uid, err := strconv.ParseUint(uidStr, 10, 32)
if err != nil {
- log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr)
+ log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), uidStr)
return nil, nil, syserror.EINVAL
}
- fsopts.userID = uint32(userID)
+ kuid := creds.UserNamespace.MapToKUID(auth.UID(uid))
+ if !kuid.Ok() {
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped uid: %d", uid)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.uid = kuid
+ } else {
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option user_id missing")
+ return nil, nil, syserror.EINVAL
}
- if groupIDStr, ok := mopts["group_id"]; ok {
+ if gidStr, ok := mopts["group_id"]; ok {
delete(mopts, "group_id")
- groupID, err := strconv.ParseUint(groupIDStr, 10, 32)
+ gid, err := strconv.ParseUint(gidStr, 10, 32)
if err != nil {
- log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr)
+ log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), gidStr)
+ return nil, nil, syserror.EINVAL
+ }
+ kgid := creds.UserNamespace.MapToKGID(auth.GID(gid))
+ if !kgid.Ok() {
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped gid: %d", gid)
return nil, nil, syserror.EINVAL
}
- fsopts.groupID = uint32(groupID)
+ fsopts.gid = kgid
+ } else {
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option group_id missing")
+ return nil, nil, syserror.EINVAL
}
- rootMode := linux.FileMode(0777)
- modeStr, ok := mopts["rootmode"]
- if ok {
+ if modeStr, ok := mopts["rootmode"]; ok {
delete(mopts, "rootmode")
mode, err := strconv.ParseUint(modeStr, 8, 32)
if err != nil {
log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr)
return nil, nil, syserror.EINVAL
}
- rootMode = linux.FileMode(mode)
+ fsopts.rootMode = linux.FileMode(mode)
+ } else {
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option rootmode missing")
+ return nil, nil, syserror.EINVAL
}
- fsopts.rootMode = rootMode
// Set the maxInFlightRequests option.
fsopts.maxActiveRequests = maxActiveRequestsDefault
@@ -192,6 +216,16 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fsopts.maxRead = math.MaxUint32
}
+ if _, ok := mopts["default_permissions"]; ok {
+ delete(mopts, "default_permissions")
+ fsopts.defaultPermissions = true
+ }
+
+ if _, ok := mopts["allow_other"]; ok {
+ delete(mopts, "allow_other")
+ fsopts.allowOther = true
+ }
+
// Check for unparsed options.
if len(mopts) != 0 {
log.Warningf("%s.GetFilesystem: unsupported or unknown options: %v", fsType.Name(), mopts)
@@ -260,6 +294,11 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.Release(ctx)
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return fs.opts.mopts
+}
+
// inode implements kernfs.Inode.
//
// +stateify savable
@@ -318,6 +357,37 @@ func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FU
return i
}
+// CheckPermissions implements kernfs.Inode.CheckPermissions.
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ // Since FUSE operations are ultimately backed by a userspace process (the
+ // fuse daemon), allowing a process to call into fusefs grants the daemon
+ // ptrace-like capabilities over the calling process. Because of this, by
+ // default FUSE only allows the mount owner to interact with the
+ // filesystem. This explicitly excludes setuid/setgid processes.
+ //
+ // This behaviour can be overriden with the 'allow_other' mount option.
+ //
+ // See fs/fuse/dir.c:fuse_allow_current_process() in Linux.
+ if !i.fs.opts.allowOther {
+ if creds.RealKUID != i.fs.opts.uid ||
+ creds.EffectiveKUID != i.fs.opts.uid ||
+ creds.SavedKUID != i.fs.opts.uid ||
+ creds.RealKGID != i.fs.opts.gid ||
+ creds.EffectiveKGID != i.fs.opts.gid ||
+ creds.SavedKGID != i.fs.opts.gid {
+ return syserror.EACCES
+ }
+ }
+
+ // By default, fusefs delegates all permission checks to the server.
+ // However, standard unix permission checks can be enabled with the
+ // default_permissions mount option.
+ if i.fs.opts.defaultPermissions {
+ return i.InodeAttrs.CheckPermissions(ctx, creds, ats)
+ }
+ return nil
+}
+
// Open implements kernfs.Inode.Open.
func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
isDir := i.InodeAttrs.Mode().IsDir()
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 8f95473b6..c34451269 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -15,7 +15,9 @@
package gofer
import (
+ "fmt"
"math"
+ "strings"
"sync"
"sync/atomic"
@@ -1608,3 +1610,58 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.renameMu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
+
+type mopt struct {
+ key string
+ value interface{}
+}
+
+func (m mopt) String() string {
+ if m.value == nil {
+ return fmt.Sprintf("%s", m.key)
+ }
+ return fmt.Sprintf("%s=%v", m.key, m.value)
+}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ optsKV := []mopt{
+ {moptTransport, transportModeFD}, // Only valid value, currently.
+ {moptReadFD, fs.opts.fd}, // Currently, read and write FD are the same.
+ {moptWriteFD, fs.opts.fd}, // Currently, read and write FD are the same.
+ {moptAname, fs.opts.aname},
+ {moptDfltUID, fs.opts.dfltuid},
+ {moptDfltGID, fs.opts.dfltgid},
+ {moptMsize, fs.opts.msize},
+ {moptVersion, fs.opts.version},
+ {moptDentryCacheLimit, fs.opts.maxCachedDentries},
+ }
+
+ switch fs.opts.interop {
+ case InteropModeExclusive:
+ optsKV = append(optsKV, mopt{moptCache, cacheFSCache})
+ case InteropModeWritethrough:
+ optsKV = append(optsKV, mopt{moptCache, cacheFSCacheWritethrough})
+ case InteropModeShared:
+ if fs.opts.regularFilesUseSpecialFileFD {
+ optsKV = append(optsKV, mopt{moptCache, cacheNone})
+ } else {
+ optsKV = append(optsKV, mopt{moptCache, cacheRemoteRevalidating})
+ }
+ }
+ if fs.opts.forcePageCache {
+ optsKV = append(optsKV, mopt{moptForcePageCache, nil})
+ }
+ if fs.opts.limitHostFDTranslation {
+ optsKV = append(optsKV, mopt{moptLimitHostFDTranslation, nil})
+ }
+ if fs.opts.overlayfsStaleRead {
+ optsKV = append(optsKV, mopt{moptOverlayfsStaleRead, nil})
+ }
+
+ opts := make([]string, 0, len(optsKV))
+ for _, opt := range optsKV {
+ opts = append(opts, opt.String())
+ }
+ return strings.Join(opts, ",")
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 1508cbdf1..71569dc65 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -66,6 +66,34 @@ import (
// Name is the default filesystem name.
const Name = "9p"
+// Mount option names for goferfs.
+const (
+ moptTransport = "trans"
+ moptReadFD = "rfdno"
+ moptWriteFD = "wfdno"
+ moptAname = "aname"
+ moptDfltUID = "dfltuid"
+ moptDfltGID = "dfltgid"
+ moptMsize = "msize"
+ moptVersion = "version"
+ moptDentryCacheLimit = "dentry_cache_limit"
+ moptCache = "cache"
+ moptForcePageCache = "force_page_cache"
+ moptLimitHostFDTranslation = "limit_host_fd_translation"
+ moptOverlayfsStaleRead = "overlayfs_stale_read"
+)
+
+// Valid values for the "cache" mount option.
+const (
+ cacheNone = "none"
+ cacheFSCache = "fscache"
+ cacheFSCacheWritethrough = "fscache_writethrough"
+ cacheRemoteRevalidating = "remote_revalidating"
+)
+
+// Valid values for "trans" mount option.
+const transportModeFD = "fd"
+
// FilesystemType implements vfs.FilesystemType.
//
// +stateify savable
@@ -301,39 +329,39 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Get the attach name.
fsopts.aname = "/"
- if aname, ok := mopts["aname"]; ok {
- delete(mopts, "aname")
+ if aname, ok := mopts[moptAname]; ok {
+ delete(mopts, moptAname)
fsopts.aname = aname
}
// Parse the cache policy. For historical reasons, this defaults to the
// least generally-applicable option, InteropModeExclusive.
fsopts.interop = InteropModeExclusive
- if cache, ok := mopts["cache"]; ok {
- delete(mopts, "cache")
+ if cache, ok := mopts[moptCache]; ok {
+ delete(mopts, moptCache)
switch cache {
- case "fscache":
+ case cacheFSCache:
fsopts.interop = InteropModeExclusive
- case "fscache_writethrough":
+ case cacheFSCacheWritethrough:
fsopts.interop = InteropModeWritethrough
- case "none":
+ case cacheNone:
fsopts.regularFilesUseSpecialFileFD = true
fallthrough
- case "remote_revalidating":
+ case cacheRemoteRevalidating:
fsopts.interop = InteropModeShared
default:
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: cache=%s", cache)
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: %s=%s", moptCache, cache)
return nil, nil, syserror.EINVAL
}
}
// Parse the default UID and GID.
fsopts.dfltuid = _V9FS_DEFUID
- if dfltuidstr, ok := mopts["dfltuid"]; ok {
- delete(mopts, "dfltuid")
+ if dfltuidstr, ok := mopts[moptDfltUID]; ok {
+ delete(mopts, moptDfltUID)
dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr)
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltUID, dfltuidstr)
return nil, nil, syserror.EINVAL
}
// In Linux, dfltuid is interpreted as a UID and is converted to a KUID
@@ -342,11 +370,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fsopts.dfltuid = auth.KUID(dfltuid)
}
fsopts.dfltgid = _V9FS_DEFGID
- if dfltgidstr, ok := mopts["dfltgid"]; ok {
- delete(mopts, "dfltgid")
+ if dfltgidstr, ok := mopts[moptDfltGID]; ok {
+ delete(mopts, moptDfltGID)
dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltgid=%s", dfltgidstr)
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltGID, dfltgidstr)
return nil, nil, syserror.EINVAL
}
fsopts.dfltgid = auth.KGID(dfltgid)
@@ -354,11 +382,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Parse the 9P message size.
fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M
- if msizestr, ok := mopts["msize"]; ok {
- delete(mopts, "msize")
+ if msizestr, ok := mopts[moptMsize]; ok {
+ delete(mopts, moptMsize)
msize, err := strconv.ParseUint(msizestr, 10, 32)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: msize=%s", msizestr)
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: %s=%s", moptMsize, msizestr)
return nil, nil, syserror.EINVAL
}
fsopts.msize = uint32(msize)
@@ -366,34 +394,34 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Parse the 9P protocol version.
fsopts.version = p9.HighestVersionString()
- if version, ok := mopts["version"]; ok {
- delete(mopts, "version")
+ if version, ok := mopts[moptVersion]; ok {
+ delete(mopts, moptVersion)
fsopts.version = version
}
// Parse the dentry cache limit.
fsopts.maxCachedDentries = 1000
- if str, ok := mopts["dentry_cache_limit"]; ok {
- delete(mopts, "dentry_cache_limit")
+ if str, ok := mopts[moptDentryCacheLimit]; ok {
+ delete(mopts, moptDentryCacheLimit)
maxCachedDentries, err := strconv.ParseUint(str, 10, 64)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: %s=%s", moptDentryCacheLimit, str)
return nil, nil, syserror.EINVAL
}
fsopts.maxCachedDentries = maxCachedDentries
}
// Handle simple flags.
- if _, ok := mopts["force_page_cache"]; ok {
- delete(mopts, "force_page_cache")
+ if _, ok := mopts[moptForcePageCache]; ok {
+ delete(mopts, moptForcePageCache)
fsopts.forcePageCache = true
}
- if _, ok := mopts["limit_host_fd_translation"]; ok {
- delete(mopts, "limit_host_fd_translation")
+ if _, ok := mopts[moptLimitHostFDTranslation]; ok {
+ delete(mopts, moptLimitHostFDTranslation)
fsopts.limitHostFDTranslation = true
}
- if _, ok := mopts["overlayfs_stale_read"]; ok {
- delete(mopts, "overlayfs_stale_read")
+ if _, ok := mopts[moptOverlayfsStaleRead]; ok {
+ delete(mopts, moptOverlayfsStaleRead)
fsopts.overlayfsStaleRead = true
}
// fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying
@@ -469,34 +497,34 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) {
// Check that the transport is "fd".
- trans, ok := mopts["trans"]
- if !ok || trans != "fd" {
- ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'")
+ trans, ok := mopts[moptTransport]
+ if !ok || trans != transportModeFD {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as '%s=%s'", moptTransport, transportModeFD)
return -1, syserror.EINVAL
}
- delete(mopts, "trans")
+ delete(mopts, moptTransport)
// Check that read and write FDs are provided and identical.
- rfdstr, ok := mopts["rfdno"]
+ rfdstr, ok := mopts[moptReadFD]
if !ok {
- ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'")
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as '%s=<file descriptor>'", moptReadFD)
return -1, syserror.EINVAL
}
- delete(mopts, "rfdno")
+ delete(mopts, moptReadFD)
rfd, err := strconv.Atoi(rfdstr)
if err != nil {
- ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr)
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: %s=%s", moptReadFD, rfdstr)
return -1, syserror.EINVAL
}
- wfdstr, ok := mopts["wfdno"]
+ wfdstr, ok := mopts[moptWriteFD]
if !ok {
- ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'")
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as '%s=<file descriptor>'", moptWriteFD)
return -1, syserror.EINVAL
}
- delete(mopts, "wfdno")
+ delete(mopts, moptWriteFD)
wfd, err := strconv.Atoi(wfdstr)
if err != nil {
- ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr)
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: %s=%s", moptWriteFD, wfdstr)
return -1, syserror.EINVAL
}
if rfd != wfd {
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index ad5de80dc..b9cce4181 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -260,6 +260,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
return vfs.PrependPathSyntheticError{}
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
// CheckPermissions implements kernfs.Inode.CheckPermissions.
func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
var s unix.Stat_t
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index e63588e33..1cd3137e6 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -67,6 +67,11 @@ type filesystem struct {
kernfs.Filesystem
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
type file struct {
kernfs.DynamicBytesFile
content string
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index f7f795b10..84e37f793 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -85,6 +85,8 @@ func putDentrySlice(ds *[]*dentry) {
// but dentry slices are allocated lazily, and it's much easier to say "defer
// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
+//
+// +checklocks:fs.renameMu
func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*dentry) {
fs.renameMu.RUnlock()
if *dsp == nil {
@@ -110,6 +112,7 @@ func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*
putDentrySlice(*dsp)
}
+// +checklocks:fs.renameMu
func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
if *ds == nil {
fs.renameMu.Unlock()
@@ -1761,3 +1764,15 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.renameMu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ // Return the mount options from the topmost layer.
+ var vd vfs.VirtualDentry
+ if fs.opts.UpperRoot.Ok() {
+ vd = fs.opts.UpperRoot
+ } else {
+ vd = fs.opts.LowerRoots[0]
+ }
+ return vd.Mount().Filesystem().Impl().MountOptions()
+}
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index 429733c10..3f05e444e 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -80,6 +80,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
return vfs.PrependPathSyntheticError{}
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
// inode implements kernfs.Inode.
//
// +stateify savable
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 8716d0a3c..254a8b062 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -104,6 +104,11 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.Release(ctx)
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return fmt.Sprintf("dentry_cache_limit=%d", fs.MaxCachedDentries)
+}
+
// dynamicInode is an overfitted interface for common Inodes with
// dynamicByteSource types used in procfs.
//
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index fd7823daa..fb274b78e 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -17,6 +17,7 @@ package proc
import (
"bytes"
"fmt"
+ "math"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -69,17 +70,17 @@ func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials,
if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]kernfs.Inode{
"ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
- "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}),
- "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
- "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}),
- "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
- "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}),
+ "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}),
+ "ip_local_port_range": fs.newInode(ctx, root, 0644, &portRange{stack: stack}),
+ "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}),
+ "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
+ "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}),
+ "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
// The following files are simple stubs until they are implemented in
// netstack, most of these files are configuration related. We use the
// value closest to the actual netstack behavior or any empty file, all
// of these files will have mode 0444 (read-only for all users).
- "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")),
"ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")),
"ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")),
"ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")),
@@ -421,3 +422,68 @@ func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offs
}
return n, nil
}
+
+// portRange implements vfs.WritableDynamicBytesSource for
+// /proc/sys/net/ipv4/ip_local_port_range.
+//
+// +stateify savable
+type portRange struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack `state:"wait"`
+
+ // start and end store the port range. We must save/restore this here,
+ // since a netstack instance is created on restore.
+ start *uint16
+ end *uint16
+}
+
+var _ vfs.WritableDynamicBytesSource = (*portRange)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (pr *portRange) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if pr.start == nil {
+ start, end := pr.stack.PortRange()
+ pr.start = &start
+ pr.end = &end
+ }
+ _, err := fmt.Fprintf(buf, "%d %d\n", *pr.start, *pr.end)
+ return err
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ // No need to handle partial writes thus far.
+ return 0, syserror.EINVAL
+ }
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is
+ // large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ ports := make([]int32, 2)
+ n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ // Port numbers must be uint16s.
+ if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 {
+ return 0, syserror.EINVAL
+ }
+
+ if err := pr.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil {
+ return 0, err
+ }
+ if pr.start == nil {
+ pr.start = new(uint16)
+ pr.end = new(uint16)
+ }
+ *pr.start = uint16(ports[0])
+ *pr.end = uint16(ports[1])
+ return n, nil
+}
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index fda1fa942..735756280 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -85,6 +85,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
return vfs.PrependPathSyntheticError{}
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
// inode implements kernfs.Inode.
//
// +stateify savable
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index dbd9ebdda..1d9280dae 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -143,6 +143,11 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.Release(ctx)
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return fmt.Sprintf("dentry_cache_limit=%d", fs.MaxCachedDentries)
+}
+
// dir implements kernfs.Inode.
//
// +stateify savable
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 4f675c21e..5fdca1d46 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -898,3 +898,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
d = d.parent
}
}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return fs.mopts
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index a01e413e0..8df81f589 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -70,6 +70,10 @@ type filesystem struct {
// devMinor is the filesystem's minor device number. devMinor is immutable.
devMinor uint32
+ // mopts contains the tmpfs-specific mount options passed to this
+ // filesystem. Immutable.
+ mopts string
+
// mu serializes changes to the Dentry tree.
mu sync.RWMutex `state:"nosave"`
@@ -184,6 +188,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
mfp: mfp,
clock: clock,
devMinor: devMinor,
+ mopts: opts.Data,
}
fs.vfsfs.Init(vfsObj, newFSType, &fs)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 9057d2b4e..6cb1a23e0 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -590,6 +590,23 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
return nil, err
}
+ // Clear the Merkle tree file if they are to be generated at runtime.
+ // TODO(b/182315468): Optimize the Merkle tree generate process to
+ // allow only updating certain files/directories.
+ if fs.allowRuntimeEnable {
+ childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: childMerkleVD,
+ Start: childMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_TRUNC,
+ Mode: 0644,
+ })
+ if err != nil {
+ return nil, err
+ }
+ childMerkleFD.DecRef(ctx)
+ }
+
// The dentry needs to be cleaned up if any error occurs. IncRef will be
// called if a verity child dentry is successfully created.
defer childMerkleVD.DecRef(ctx)
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 374f71568..0d9b0ee2c 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -38,6 +38,7 @@ import (
"fmt"
"math"
"strconv"
+ "strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -310,6 +311,24 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
d.DecRef(ctx)
return nil, nil, alertIntegrityViolation("Failed to find root Merkle file")
}
+
+ // Clear the Merkle tree file if they are to be generated at runtime.
+ // TODO(b/182315468): Optimize the Merkle tree generate process to
+ // allow only updating certain files/directories.
+ if fs.allowRuntimeEnable {
+ lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: lowerMerkleVD,
+ Start: lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_TRUNC,
+ Mode: 0644,
+ })
+ if err != nil {
+ return nil, nil, err
+ }
+ lowerMerkleFD.DecRef(ctx)
+ }
+
d.lowerMerkleVD = lowerMerkleVD
// Get metadata from the underlying file system.
@@ -418,6 +437,11 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.lowerMount.DecRef(ctx)
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -750,6 +774,50 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return syserror.EPERM
}
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ if !fd.d.isDir() {
+ return syserror.ENOTDIR
+ }
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ var ds []vfs.Dirent
+ err := fd.lowerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
+ // Do not include the Merkle tree files.
+ if strings.Contains(dirent.Name, merklePrefix) || strings.Contains(dirent.Name, merkleRootPrefix) {
+ return nil
+ }
+ if fd.d.verityEnabled() {
+ // Verify that the child is expected.
+ if dirent.Name != "." && dirent.Name != ".." {
+ if _, ok := fd.d.childrenNames[dirent.Name]; !ok {
+ return alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name))
+ }
+ }
+ }
+ ds = append(ds, dirent)
+ return nil
+ }))
+
+ if err != nil {
+ return err
+ }
+
+ // The result should contain all children plus "." and "..".
+ if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 {
+ return alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds)))
+ }
+
+ for fd.off < int64(len(ds)) {
+ if err := cb.Handle(ds[fd.off]); err != nil {
+ return err
+ }
+ fd.off++
+ }
+ return nil
+}
+
// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
fd.mu.Lock()
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index f31277d30..6b71bd3a9 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -93,6 +93,14 @@ type Stack interface {
// SetForwarding enables or disables packet forwarding between NICs.
SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error
+
+ // PortRange returns the UDP and TCP inclusive range of ephemeral ports
+ // used in both IPv4 and IPv6.
+ PortRange() (uint16, uint16)
+
+ // SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range
+ // (inclusive).
+ SetPortRange(start uint16, end uint16) error
}
// Interface contains information about a network interface.
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 9ebeba8a3..03e2608c2 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -164,3 +164,15 @@ func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable b
s.IPForwarding = enable
return nil
}
+
+// PortRange implements inet.Stack.PortRange.
+func (*TestStack) PortRange() (uint16, uint16) {
+ // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net().
+ return 32768, 28232
+}
+
+// SetPortRange implements inet.Stack.SetPortRange.
+func (*TestStack) SetPortRange(start uint16, end uint16) error {
+ // No-op.
+ return nil
+}
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 58cc11a13..a4af3e21b 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -876,6 +876,7 @@ func (f *MemoryFile) UpdateUsage() error {
// in bs, sets committed[i] to 1 if the page is committed and 0 otherwise.
//
// Precondition: f.mu must be held; it may be unlocked and reacquired.
+// +checklocks:f.mu
func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(bs []byte, committed []byte) error) error {
// Track if anything changed to elide the merge. In the common case, we
// expect all segments to be committed and no merge to occur.
@@ -925,72 +926,73 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
r := seg.Range()
var checkErr error
- err := f.forEachMappingSlice(r, func(s []byte) {
- if checkErr != nil {
- return
- }
-
- // Ensure that we have sufficient buffer for the call
- // (one byte per page). The length of each slice must
- // be page-aligned.
- bufLen := len(s) / usermem.PageSize
- if len(buf) < bufLen {
- buf = make([]byte, bufLen)
- }
+ err := f.forEachMappingSlice(r,
+ func(s []byte) {
+ if checkErr != nil {
+ return
+ }
- // Query for new pages in core.
- // NOTE(b/165896008): mincore (which is passed as checkCommitted)
- // by f.UpdateUsage() might take a really long time. So unlock f.mu
- // while checkCommitted runs.
- f.mu.Unlock()
- err := checkCommitted(s, buf)
- f.mu.Lock()
- if err != nil {
- checkErr = err
- return
- }
+ // Ensure that we have sufficient buffer for the call
+ // (one byte per page). The length of each slice must
+ // be page-aligned.
+ bufLen := len(s) / usermem.PageSize
+ if len(buf) < bufLen {
+ buf = make([]byte, bufLen)
+ }
- // Scan each page and switch out segments.
- seg := f.usage.LowerBoundSegment(r.Start)
- for i := 0; i < bufLen; {
- if buf[i]&0x1 == 0 {
- i++
- continue
+ // Query for new pages in core.
+ // NOTE(b/165896008): mincore (which is passed as checkCommitted)
+ // by f.UpdateUsage() might take a really long time. So unlock f.mu
+ // while checkCommitted runs.
+ f.mu.Unlock()
+ err := checkCommitted(s, buf)
+ f.mu.Lock()
+ if err != nil {
+ checkErr = err
+ return
}
- // Scan to the end of this committed range.
- j := i + 1
- for ; j < bufLen; j++ {
- if buf[j]&0x1 == 0 {
- break
+
+ // Scan each page and switch out segments.
+ seg := f.usage.LowerBoundSegment(r.Start)
+ for i := 0; i < bufLen; {
+ if buf[i]&0x1 == 0 {
+ i++
+ continue
}
- }
- committedFR := memmap.FileRange{
- Start: r.Start + uint64(i*usermem.PageSize),
- End: r.Start + uint64(j*usermem.PageSize),
- }
- // Advance seg to committedFR.Start.
- for seg.Ok() && seg.End() < committedFR.Start {
- seg = seg.NextSegment()
- }
- // Mark pages overlapping committedFR as committed.
- for seg.Ok() && seg.Start() < committedFR.End {
- if seg.ValuePtr().canCommit() {
- seg = f.usage.Isolate(seg, committedFR)
- seg.ValuePtr().knownCommitted = true
- amount := seg.Range().Length()
- usage.MemoryAccounting.Inc(amount, seg.ValuePtr().kind)
- f.usageExpected += amount
- changedAny = true
+ // Scan to the end of this committed range.
+ j := i + 1
+ for ; j < bufLen; j++ {
+ if buf[j]&0x1 == 0 {
+ break
+ }
}
- seg = seg.NextSegment()
+ committedFR := memmap.FileRange{
+ Start: r.Start + uint64(i*usermem.PageSize),
+ End: r.Start + uint64(j*usermem.PageSize),
+ }
+ // Advance seg to committedFR.Start.
+ for seg.Ok() && seg.End() < committedFR.Start {
+ seg = seg.NextSegment()
+ }
+ // Mark pages overlapping committedFR as committed.
+ for seg.Ok() && seg.Start() < committedFR.End {
+ if seg.ValuePtr().canCommit() {
+ seg = f.usage.Isolate(seg, committedFR)
+ seg.ValuePtr().knownCommitted = true
+ amount := seg.Range().Length()
+ usage.MemoryAccounting.Inc(amount, seg.ValuePtr().kind)
+ f.usageExpected += amount
+ changedAny = true
+ }
+ seg = seg.NextSegment()
+ }
+ // Continue scanning for committed pages.
+ i = j + 1
}
- // Continue scanning for committed pages.
- i = j + 1
- }
- // Advance r.Start.
- r.Start += uint64(len(s))
- })
+ // Advance r.Start.
+ r.Start += uint64(len(s))
+ })
if checkErr != nil {
return checkErr
}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index e6323244c..5bcf92e14 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -504,3 +504,14 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
return syserror.EACCES
}
+
+// PortRange implements inet.Stack.PortRange.
+func (*Stack) PortRange() (uint16, uint16) {
+ // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net().
+ return 32768, 28232
+}
+
+// SetPortRange implements inet.Stack.SetPortRange.
+func (*Stack) SetPortRange(start uint16, end uint16) error {
+ return syserror.EACCES
+}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index f2dc7c90b..9efb195f0 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -83,110 +83,121 @@ var Metrics = tcpip.Stats{
V4: tcpip.ICMPv4Stats{
PacketsSent: tcpip.ICMPv4SentPacketStats{
ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent by netstack."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped by netstack due to link layer errors."),
+ RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped by netstack due to rate limit being exceeded."),
},
PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received by netstack."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
+ Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Number of ICMPv4 packets received that the transport layer could not parse."),
},
},
V6: tcpip.ICMPv6Stats{
PacketsSent: tcpip.ICMPv6SentPacketStats{
ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent by netstack."),
+ MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent by netstack."),
+ MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."),
+ MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped by netstack due to rate limit being exceeded."),
},
PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received by netstack."),
+ MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received by netstack."),
+ MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."),
+ MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
+ Unrecognized: mustCreateMetric("/netstack/icmp/v6/packets_received/unrecognized", "Number of ICMPv6 packets received that the transport layer does not know how to parse."),
+ Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Number of ICMPv6 packets received that the transport layer could not parse."),
+ RouterOnlyPacketsDroppedByHost: mustCreateMetric("/netstack/icmp/v6/packets_received/router_only_packets_dropped_by_host", "Number of ICMPv6 packets dropped due to being router-specific packets."),
},
},
},
IGMP: tcpip.IGMPStats{
PacketsSent: tcpip.IGMPSentPacketStats{
IGMPPacketStats: tcpip.IGMPPacketStats{
- MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."),
- V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."),
- V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."),
- LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."),
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent by netstack."),
},
- Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."),
+ Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped by netstack due to link layer errors."),
},
PacketsReceived: tcpip.IGMPReceivedPacketStats{
IGMPPacketStats: tcpip.IGMPPacketStats{
- MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."),
- V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."),
- V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."),
- LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group messages received by netstack."),
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received by netstack."),
},
- Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."),
- ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."),
- Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."),
+ Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received by netstack that could not be parsed."),
+ ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Number of received IGMP packets with bad checksums."),
+ Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received by netstack."),
},
},
IP: tcpip.IPStats{
- PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
- InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."),
- InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Total number of IP packets received with an unknown or invalid source address."),
- PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
- PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."),
- OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
- MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."),
- MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."),
- IPTablesPreroutingDropped: mustCreateMetric("/netstack/ip/iptables/prerouting_dropped", "Total number of IP packets dropped in the Prerouting chain."),
- IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Total number of IP packets dropped in the Input chain."),
- IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Total number of IP packets dropped in the Output chain."),
- OptionTimestampReceived: mustCreateMetric("/netstack/ip/options/timestamp_received", "Total number of timestamp options found in received IP packets."),
- OptionRecordRouteReceived: mustCreateMetric("/netstack/ip/options/record_route_received", "Total number of record route options found in received IP packets."),
- OptionRouterAlertReceived: mustCreateMetric("/netstack/ip/options/router_alert_received", "Total number of router alert options found in received IP packets."),
- OptionUnknownReceived: mustCreateMetric("/netstack/ip/options/unknown_received", "Total number of unknown options found in received IP packets."),
+ PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
+ DisabledPacketsReceived: mustCreateMetric("/netstack/ip/disabled_packets_received", "Number of IP packets received from the link layer when the IP layer is disabled."),
+ InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Number of IP packets received with an unknown or invalid destination address."),
+ InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Number of IP packets received with an unknown or invalid source address."),
+ PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
+ PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Number of IP packets sent via WritePacket."),
+ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Number of IP packets which failed to write to a link-layer endpoint."),
+ MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Number of IP packets which failed IP header validation checks."),
+ MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Number of IP fragments which failed IP fragment validation checks."),
+ IPTablesPreroutingDropped: mustCreateMetric("/netstack/ip/iptables/prerouting_dropped", "Number of IP packets dropped in the Prerouting chain."),
+ IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Number of IP packets dropped in the Input chain."),
+ IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Number of IP packets dropped in the Output chain."),
+ OptionTimestampReceived: mustCreateMetric("/netstack/ip/options/timestamp_received", "Number of timestamp options found in received IP packets."),
+ OptionRecordRouteReceived: mustCreateMetric("/netstack/ip/options/record_route_received", "Number of record route options found in received IP packets."),
+ OptionRouterAlertReceived: mustCreateMetric("/netstack/ip/options/router_alert_received", "Number of router alert options found in received IP packets."),
+ OptionUnknownReceived: mustCreateMetric("/netstack/ip/options/unknown_received", "Number of unknown options found in received IP packets."),
},
ARP: tcpip.ARPStats{
PacketsReceived: mustCreateMetric("/netstack/arp/packets_received", "Number of ARP packets received from the link layer."),
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index cc0fadeb5..b215067cf 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -336,7 +336,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.ParamProblem.Value(), // InParmProbs.
in.SrcQuench.Value(), // InSrcQuenchs.
in.Redirect.Value(), // InRedirects.
- in.Echo.Value(), // InEchos.
+ in.EchoRequest.Value(), // InEchos.
in.EchoReply.Value(), // InEchoReps.
in.Timestamp.Value(), // InTimestamps.
in.TimestampReply.Value(), // InTimestampReps.
@@ -349,7 +349,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
out.ParamProblem.Value(), // OutParmProbs.
out.SrcQuench.Value(), // OutSrcQuenchs.
out.Redirect.Value(), // OutRedirects.
- out.Echo.Value(), // OutEchos.
+ out.EchoRequest.Value(), // OutEchos.
out.EchoReply.Value(), // OutEchoReps.
out.Timestamp.Value(), // OutTimestamps.
out.TimestampReply.Value(), // OutTimestampReps.
@@ -478,3 +478,13 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool)
}
return nil
}
+
+// PortRange implements inet.Stack.PortRange.
+func (s *Stack) PortRange() (uint16, uint16) {
+ return s.Stack.PortRange()
+}
+
+// SetPortRange implements inet.Stack.SetPortRange.
+func (s *Stack) SetPortRange(start uint16, end uint16) error {
+ return syserr.TranslateNetstackError(s.Stack.SetPortRange(start, end)).ToError()
+}
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index 7ad0eaf86..3caf417ca 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -291,6 +291,11 @@ func (fs *anonFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDe
return PrependPathSyntheticError{}
}
+// MountOptions implements FilesystemImpl.MountOptions.
+func (fs *anonFilesystem) MountOptions() string {
+ return ""
+}
+
// IncRef implements DentryImpl.IncRef.
func (d *anonDentry) IncRef() {
// no-op
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index 320ab7ce1..e7ca24d96 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -211,12 +211,14 @@ func (vfs *VirtualFilesystem) PrepareDeleteDentry(mntns *MountNamespace, d *Dent
// AbortDeleteDentry must be called after PrepareDeleteDentry if the deletion
// fails.
+// +checklocks:d.mu
func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) {
d.mu.Unlock()
}
// CommitDeleteDentry must be called after PrepareDeleteDentry if the deletion
// succeeds.
+// +checklocks:d.mu
func (vfs *VirtualFilesystem) CommitDeleteDentry(ctx context.Context, d *Dentry) {
d.dead = true
d.mu.Unlock()
@@ -270,6 +272,8 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t
// AbortRenameDentry must be called after PrepareRenameDentry if the rename
// fails.
+// +checklocks:from.mu
+// +checklocks:to.mu
func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) {
from.mu.Unlock()
if to != nil {
@@ -282,6 +286,8 @@ func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) {
// that was replaced by from.
//
// Preconditions: PrepareRenameDentry was previously called on from and to.
+// +checklocks:from.mu
+// +checklocks:to.mu
func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(ctx context.Context, from, to *Dentry) {
from.mu.Unlock()
if to != nil {
@@ -297,6 +303,8 @@ func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(ctx context.Context, fro
// from and to are exchanged by rename(RENAME_EXCHANGE).
//
// Preconditions: PrepareRenameDentry was previously called on from and to.
+// +checklocks:from.mu
+// +checklocks:to.mu
func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) {
from.mu.Unlock()
to.mu.Unlock()
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 2c4b81e78..059939010 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -502,6 +502,15 @@ type FilesystemImpl interface {
//
// Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl.
PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error
+
+ // MountOptions returns mount options for the current filesystem. This
+ // should only return options specific to the filesystem (i.e. don't return
+ // "ro", "rw", etc). Options should be returned as a comma-separated string,
+ // similar to the input to the 5th argument to mount.
+ //
+ // If the implementation has no filesystem-specific options, it should
+ // return the empty string.
+ MountOptions() string
}
// PrependPathAtVFSRootError is returned by implementations of
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 7063066ff..922f9e697 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -217,20 +217,21 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr
return err
}
vfs.mountMu.Lock()
- vd.dentry.mu.Lock()
+ vdDentry := vd.dentry
+ vdDentry.mu.Lock()
for {
- if vd.dentry.dead {
- vd.dentry.mu.Unlock()
+ if vdDentry.dead {
+ vdDentry.mu.Unlock()
vfs.mountMu.Unlock()
vd.DecRef(ctx)
return syserror.ENOENT
}
// vd might have been mounted over between vfs.GetDentryAt() and
// vfs.mountMu.Lock().
- if !vd.dentry.isMounted() {
+ if !vdDentry.isMounted() {
break
}
- nextmnt := vfs.mounts.Lookup(vd.mount, vd.dentry)
+ nextmnt := vfs.mounts.Lookup(vd.mount, vdDentry)
if nextmnt == nil {
break
}
@@ -243,13 +244,13 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr
}
// This can't fail since we're holding vfs.mountMu.
nextmnt.root.IncRef()
- vd.dentry.mu.Unlock()
+ vdDentry.mu.Unlock()
vd.DecRef(ctx)
vd = VirtualDentry{
mount: nextmnt,
dentry: nextmnt.root,
}
- vd.dentry.mu.Lock()
+ vdDentry.mu.Lock()
}
// TODO(gvisor.dev/issue/1035): Linux requires that either both the mount
// point and the mount root are directories, or neither are, and returns
@@ -258,7 +259,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr
vfs.mounts.seq.BeginWrite()
vfs.connectLocked(mnt, vd, mntns)
vfs.mounts.seq.EndWrite()
- vd.dentry.mu.Unlock()
+ vdDentry.mu.Unlock()
vfs.mountMu.Unlock()
return nil
}
@@ -958,13 +959,17 @@ func manglePath(p string) string {
// superBlockOpts returns the super block options string for the the mount at
// the given path.
func superBlockOpts(mountPath string, mnt *Mount) string {
- // gVisor doesn't (yet) have a concept of super block options, so we
- // use the ro/rw bit from the mount flag.
+ // Compose super block options by combining global mount flags with
+ // FS-specific mount options.
opts := "rw"
if mnt.ReadOnly() {
opts = "ro"
}
+ if mopts := mnt.fs.Impl().MountOptions(); mopts != "" {
+ opts += "," + mopts
+ }
+
// NOTE(b/147673608): If the mount is a cgroup, we also need to include
// the cgroup name in the options. For now we just read that from the
// path.
diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go
index 4e49a9b89..411a80a8a 100644
--- a/pkg/sync/mutex_unsafe.go
+++ b/pkg/sync/mutex_unsafe.go
@@ -72,6 +72,7 @@ func (m *Mutex) Lock() {
// Preconditions:
// * m is locked.
// * m was locked by this goroutine.
+// +checklocksignore
func (m *Mutex) Unlock() {
noteUnlock(unsafe.Pointer(m))
m.m.Unlock()
diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go
index 4cf3fcd6e..892d3e641 100644
--- a/pkg/sync/rwmutex_unsafe.go
+++ b/pkg/sync/rwmutex_unsafe.go
@@ -105,6 +105,7 @@ func (rw *CrossGoroutineRWMutex) RUnlock() {
// TryLock locks rw for writing. It returns true if it succeeds and false
// otherwise. It does not block.
+// +checklocksignore
func (rw *CrossGoroutineRWMutex) TryLock() bool {
if RaceEnabled {
RaceDisable()
@@ -155,6 +156,7 @@ func (rw *CrossGoroutineRWMutex) Lock() {
//
// Preconditions:
// * rw is locked for writing.
+// +checklocksignore
func (rw *CrossGoroutineRWMutex) Unlock() {
if RaceEnabled {
RaceRelease(unsafe.Pointer(&rw.writerSem))
@@ -181,6 +183,7 @@ func (rw *CrossGoroutineRWMutex) Unlock() {
//
// Preconditions:
// * rw is locked for writing.
+// +checklocksignore
func (rw *CrossGoroutineRWMutex) DowngradeLock() {
if RaceEnabled {
RaceRelease(unsafe.Pointer(&rw.readerSem))
@@ -250,6 +253,7 @@ func (rw *RWMutex) RLock() {
// Preconditions:
// * rw is locked for reading.
// * rw was locked by this goroutine.
+// +checklocksignore
func (rw *RWMutex) RUnlock() {
rw.m.RUnlock()
noteUnlock(unsafe.Pointer(rw))
@@ -279,6 +283,7 @@ func (rw *RWMutex) Lock() {
// Preconditions:
// * rw is locked for writing.
// * rw was locked by this goroutine.
+// +checklocksignore
func (rw *RWMutex) Unlock() {
rw.m.Unlock()
noteUnlock(unsafe.Pointer(rw))
@@ -288,6 +293,7 @@ func (rw *RWMutex) Unlock() {
//
// Preconditions:
// * rw is locked for writing.
+// +checklocksignore
func (rw *RWMutex) DowngradeLock() {
// No note change for DowngradeLock.
rw.m.DowngradeLock()
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 0b9139570..79e564de6 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -51,6 +51,7 @@ var (
ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM)
ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT)
ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), linux.EINVAL)
+ ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), linux.EINVAL)
)
// TranslateNetstackError converts an error from the tcpip package to a sentry
@@ -135,6 +136,8 @@ func TranslateNetstackError(err tcpip.Error) *Error {
return ErrBadBuffer
case *tcpip.ErrMalformedHeader:
return ErrMalformedHeader
+ case *tcpip.ErrInvalidPortRange:
+ return ErrInvalidPortRange
default:
panic(fmt.Sprintf("unknown error %T", err))
}
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index b05e81526..f4a30effd 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -196,7 +196,7 @@ func (vv *VectorisedView) CapLength(length int) {
// If the buffer argument is large enough to contain all the Views of this
// VectorisedView, the method will avoid allocations and use the buffer to
// store the Views of the clone.
-func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
+func (vv VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
@@ -290,6 +290,14 @@ func (vv *VectorisedView) AppendView(v View) {
vv.size += len(v)
}
+// AppendViews appends views to vv.
+func (vv *VectorisedView) AppendViews(views []View) {
+ vv.views = append(vv.views, views...)
+ for _, v := range views {
+ vv.size += len(v)
+ }
+}
+
// Readers returns a bytes.Reader for each of vv's views.
func (vv *VectorisedView) Readers() []bytes.Reader {
readers := make([]bytes.Reader, 0, len(vv.views))
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
index 78b2faa26..d296d9c2b 100644
--- a/pkg/tcpip/buffer/view_test.go
+++ b/pkg/tcpip/buffer/view_test.go
@@ -45,6 +45,11 @@ func vv(size int, pieces ...string) buffer.VectorisedView {
return buffer.NewVectorisedView(size, views)
}
+// v returns a buffer.View containing piece.
+func v(piece string) buffer.View {
+ return buffer.View(piece)
+}
+
var capLengthTestCases = []struct {
comment string
in buffer.VectorisedView
@@ -125,6 +130,12 @@ var trimFrontTestCases = []struct {
want: vv(1, "3"),
},
{
+ comment: "Case with one empty Views",
+ in: vv(3, "1", "", "23"),
+ count: 2,
+ want: vv(1, "3"),
+ },
+ {
comment: "Corner case with negative count",
in: vv(1, "1"),
count: -1,
@@ -566,11 +577,11 @@ func TestAppendView(t *testing.T) {
in buffer.View
want buffer.VectorisedView
}{
- {buffer.VectorisedView{}, nil, buffer.VectorisedView{}},
- {buffer.VectorisedView{}, buffer.View{}, buffer.VectorisedView{}},
- {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), nil, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})},
- {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{}, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})},
- {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{'e'}, buffer.NewVectorisedView(5, []buffer.View{{'a', 'b', 'c', 'd'}, {'e'}})},
+ {vv(0), nil, vv(0)},
+ {vv(0), v(""), vv(0)},
+ {vv(4, "abcd"), nil, vv(4, "abcd")},
+ {vv(4, "abcd"), v(""), vv(4, "abcd")},
+ {vv(4, "abcd"), v("e"), vv(5, "abcd", "e")},
}
for _, tc := range testCases {
tc.vv.AppendView(tc.in)
@@ -580,6 +591,31 @@ func TestAppendView(t *testing.T) {
}
}
+func TestAppendViews(t *testing.T) {
+ testCases := []struct {
+ vv buffer.VectorisedView
+ in []buffer.View
+ want buffer.VectorisedView
+ }{
+ {vv(0), nil, vv(0)},
+ {vv(0), []buffer.View{}, vv(0)},
+ {vv(0), []buffer.View{v("")}, vv(0, "")},
+ {vv(4, "abcd"), nil, vv(4, "abcd")},
+ {vv(4, "abcd"), []buffer.View{}, vv(4, "abcd")},
+ {vv(4, "abcd"), []buffer.View{v("")}, vv(4, "abcd", "")},
+ {vv(4, "abcd"), []buffer.View{v("")}, vv(4, "abcd", "")},
+ {vv(4, "abcd"), []buffer.View{v("e")}, vv(5, "abcd", "e")},
+ {vv(4, "abcd"), []buffer.View{v("e"), v("fg")}, vv(7, "abcd", "e", "fg")},
+ {vv(4, "abcd"), []buffer.View{v(""), v("fg")}, vv(6, "abcd", "", "fg")},
+ }
+ for _, tc := range testCases {
+ tc.vv.AppendViews(tc.in)
+ if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want)
+ }
+ }
+}
+
func TestMemSize(t *testing.T) {
const perViewCap = 128
views := make([]buffer.View, 2, 32)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 07b4393a4..fc622b246 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -567,7 +567,7 @@ func TCPWindowLessThanEq(window uint16) TransportChecker {
}
// TCPFlags creates a checker that checks the tcp flags.
-func TCPFlags(flags uint8) TransportChecker {
+func TCPFlags(flags header.TCPFlags) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
@@ -576,15 +576,15 @@ func TCPFlags(flags uint8) TransportChecker {
t.Fatalf("TCP header not found in h: %T", h)
}
- if f := tcp.Flags(); f != flags {
- t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
+ if got := tcp.Flags(); got != flags {
+ t.Errorf("got tcp.Flags() = %s, want %s", got, flags)
}
}
}
// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
// given mask, match the supplied flags.
-func TCPFlagsMatch(flags, mask uint8) TransportChecker {
+func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
@@ -593,8 +593,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker {
t.Fatalf("TCP header not found in h: %T", h)
}
- if f := tcp.Flags(); (f & mask) != (flags & mask) {
- t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
+ if got := tcp.Flags(); (got & mask) != (flags & mask) {
+ t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask)
}
}
}
@@ -985,7 +985,11 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker {
}
icmp := header.ICMPv6(last.Payload())
- if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want {
+ if got, want := icmp.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp,
+ Src: last.SourceAddress(),
+ Dst: last.DestinationAddress(),
+ }); got != want {
t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want)
}
diff --git a/pkg/tcpip/errors.go b/pkg/tcpip/errors.go
index 3b7cc52f3..5d478ac32 100644
--- a/pkg/tcpip/errors.go
+++ b/pkg/tcpip/errors.go
@@ -300,6 +300,19 @@ func (*ErrInvalidOptionValue) IgnoreStats() bool {
}
func (*ErrInvalidOptionValue) String() string { return "invalid option value specified" }
+// ErrInvalidPortRange indicates an attempt to set an invalid port range.
+//
+// +stateify savable
+type ErrInvalidPortRange struct{}
+
+func (*ErrInvalidPortRange) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrInvalidPortRange) IgnoreStats() bool {
+ return true
+}
+func (*ErrInvalidPortRange) String() string { return "invalid port range" }
+
// ErrMalformedHeader indicates the operation encountered a malformed header.
//
// +stateify savable
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 14a4b2b44..6aa9acfa8 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -186,42 +186,29 @@ func Checksum(buf []byte, initial uint16) uint16 {
//
// The initial checksum must have been computed on an even number of bytes.
func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
- return ChecksumVVWithOffset(vv, initial, 0, vv.Size())
+ var c Checksumer
+ for _, v := range vv.Views() {
+ c.Add([]byte(v))
+ }
+ return ChecksumCombine(initial, c.Checksum())
}
-// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the
-// bytes in the given VectorizedView.
-//
-// The initial checksum must have been computed on an even number of bytes.
-func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 {
- odd := false
- sum := initial
- for _, v := range vv.Views() {
- if len(v) == 0 {
- continue
- }
-
- if off >= len(v) {
- off -= len(v)
- continue
- }
- v = v[off:]
-
- l := len(v)
- if l > size {
- l = size
- }
- v = v[:l]
-
- sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum))
-
- size -= len(v)
- if size == 0 {
- break
- }
- off = 0
+// Checksumer calculates checksum defined in RFC 1071.
+type Checksumer struct {
+ sum uint16
+ odd bool
+}
+
+// Add adds b to checksum.
+func (c *Checksumer) Add(b []byte) {
+ if len(b) > 0 {
+ c.sum, c.odd = unrolledCalculateChecksum(b, c.odd, uint32(c.sum))
}
- return sum
+}
+
+// Checksum returns the latest checksum value.
+func (c *Checksumer) Checksum() uint16 {
+ return c.sum
}
// ChecksumCombine combines the two uint16 to form their checksum. This is done
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 5ab20ee86..d267dabd0 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -17,6 +17,7 @@
package header_test
import (
+ "bytes"
"fmt"
"math/rand"
"sync"
@@ -26,86 +27,72 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-func TestChecksumVVWithOffset(t *testing.T) {
+func TestChecksumer(t *testing.T) {
testCases := []struct {
- name string
- vv buffer.VectorisedView
- off, size int
- initial uint16
- want uint16
+ name string
+ data [][]byte
+ want uint16
}{
{
name: "empty",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
- }),
- off: 0,
- size: 0,
want: 0,
},
{
- name: "OneView",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
- }),
- off: 0,
- size: 5,
+ name: "OneOddView",
+ data: [][]byte{
+ []byte{1, 9, 0, 5, 4},
+ },
want: 1294,
},
{
- name: "TwoViews",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
- }),
- off: 0,
- size: 11,
+ name: "TwoOddViews",
+ data: [][]byte{
+ []byte{1, 9, 0, 5, 4},
+ []byte{4, 3, 7, 1, 2, 123},
+ },
want: 33819,
},
{
- name: "TwoViewsWithOffset",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
- }),
- off: 1,
- size: 11,
- want: 33819,
+ name: "OneEvenView",
+ data: [][]byte{
+ []byte{1, 9, 0, 5},
+ },
+ want: 270,
},
{
- name: "ThreeViewsWithOffset",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
- }),
- off: 7,
- size: 11,
- want: 33819,
+ name: "TwoEvenViews",
+ data: [][]byte{
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0}),
+ buffer.NewViewFromBytes([]byte{9, 0, 5, 4}),
+ },
+ want: 30981,
},
{
- name: "ThreeViewsWithInitial",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}),
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}),
- }),
- initial: 77,
- off: 7,
- size: 11,
- want: 33896,
+ name: "ThreeViews",
+ data: [][]byte{
+ []byte{77, 11, 33, 0, 55, 44},
+ []byte{98, 1, 9, 0, 5, 4},
+ []byte{4, 3, 7, 1, 2, 123, 99},
+ },
+ want: 34236,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want {
- t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want)
+ var all bytes.Buffer
+ var c header.Checksumer
+ for _, b := range tc.data {
+ c.Add(b)
+ // Append to the buffer. We will check the checksum as a whole later.
+ if _, err := all.Write(b); err != nil {
+ t.Fatalf("all.Write(b) = _, %s; want _, nil", err)
+ }
+ }
+ if got, want := c.Checksum(), tc.want; got != want {
+ t.Errorf("c.Checksum() = %d, want %d", got, want)
}
- v := tc.vv.ToView()
- v.TrimFront(tc.off)
- v.CapLength(tc.size)
- if got, want := header.Checksum(v, tc.initial), tc.want; got != want {
- t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want)
+ if got, want := header.Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want {
+ t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want)
}
})
}
@@ -228,7 +215,7 @@ func TestICMPv4Checksum(t *testing.T) {
h.SetChecksum(want)
testICMPChecksum(t, h.Checksum, func() uint16 {
- return header.ICMPv4Checksum(h, vv)
+ return header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0))
}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
}
@@ -260,6 +247,12 @@ func TestICMPv6Checksum(t *testing.T) {
h.SetChecksum(want)
testICMPChecksum(t, h.Checksum, func() uint16 {
- return header.ICMPv6Checksum(h, src, dst, vv)
+ return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: h,
+ Src: src,
+ Dst: dst,
+ PayloadCsum: header.ChecksumVV(vv, 0),
+ PayloadLen: vv.Size(),
+ })
}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index f840a4322..91c1c3cd2 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -18,7 +18,6 @@ import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// ICMPv4 represents an ICMPv4 header stored in a byte array.
@@ -198,8 +197,8 @@ func (b ICMPv4) SetSequence(sequence uint16) {
// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
// and payload.
-func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 {
- xsum := ChecksumVV(vv, 0)
+func ICMPv4Checksum(h ICMPv4, payloadCsum uint16) uint16 {
+ xsum := payloadCsum
// h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
xsum = Checksum(h[:2], xsum)
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index eca9750ab..668da623a 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -18,7 +18,6 @@ import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// ICMPv6 represents an ICMPv6 header stored in a byte array.
@@ -262,12 +261,22 @@ func (b ICMPv6) Payload() []byte {
return b[ICMPv6PayloadOffset:]
}
+// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum.
+type ICMPv6ChecksumParams struct {
+ Header ICMPv6
+ Src tcpip.Address
+ Dst tcpip.Address
+ PayloadCsum uint16
+ PayloadLen int
+}
+
// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header,
// IPv6 src/dst addresses and the payload.
-func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
- xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
+func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 {
+ h := params.Header
- xsum = ChecksumVV(vv, xsum)
+ xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src, params.Dst, uint16(len(h)+params.PayloadLen))
+ xsum = ChecksumCombine(xsum, params.PayloadCsum)
// h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
xsum = Checksum(h[:2], xsum)
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
index 2042f214a..ebb4b2c1d 100644
--- a/pkg/tcpip/header/parse/parse.go
+++ b/pkg/tcpip/header/parse/parse.go
@@ -41,7 +41,7 @@ func ARP(pkt *stack.PacketBuffer) bool {
//
// Returns true if the header was successfully parsed.
func IPv4(pkt *stack.PacketBuffer) bool {
- hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ hdr, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
if !ok {
return false
}
@@ -62,27 +62,29 @@ func IPv4(pkt *stack.PacketBuffer) bool {
ipHdr = header.IPv4(hdr)
pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
+ pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr))
return true
}
// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network
// header with the IPv6 header.
func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ hdr, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
if !ok {
return 0, 0, 0, false, false
}
ipHdr := header.IPv6(hdr)
- // dataClone consists of:
+ // Create a VV to parse the packet. We don't plan to modify anything here.
+ // dataVV consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
// - Any other payload data.
views := [8]buffer.View{}
- dataClone := pkt.Data.Clone(views[:])
- dataClone.TrimFront(header.IPv6MinimumSize)
- it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+ dataVV := buffer.NewVectorisedView(0, views[:0])
+ dataVV.AppendViews(pkt.Data().Views())
+ dataVV.TrimFront(header.IPv6MinimumSize)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataVV)
// Iterate over the IPv6 extensions to find their length.
var nextHdr tcpip.TransportProtocolNumber
@@ -98,7 +100,7 @@ traverseExtensions:
// If we exhaust the extension list, the entire packet is the IPv6 header
// and (possibly) extensions.
if done {
- extensionsSize = dataClone.Size()
+ extensionsSize = dataVV.Size()
break
}
@@ -110,12 +112,12 @@ traverseExtensions:
fragMore = extHdr.More()
}
rawPayload := it.AsRawHeader(true /* consume */)
- extensionsSize = dataClone.Size() - rawPayload.Buf.Size()
+ extensionsSize = dataVV.Size() - rawPayload.Buf.Size()
break traverseExtensions
case header.IPv6RawPayloadHeader:
// We've found the payload after any extensions.
- extensionsSize = dataClone.Size() - extHdr.Buf.Size()
+ extensionsSize = dataVV.Size() - extHdr.Buf.Size()
nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
break traverseExtensions
@@ -127,10 +129,10 @@ traverseExtensions:
// Put the IPv6 header with extensions in pkt.NetworkHeader().
hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize)
if !ok {
- panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data().Size()))
}
ipHdr = header.IPv6(hdr)
- pkt.Data.CapLength(int(ipHdr.PayloadLength()))
+ pkt.Data().CapLength(int(ipHdr.PayloadLength()))
pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
return nextHdr, fragID, fragOffset, fragMore, true
@@ -153,13 +155,13 @@ func UDP(pkt *stack.PacketBuffer) bool {
func TCP(pkt *stack.PacketBuffer) bool {
// TCP header is variable length, peek at it first.
hdrLen := header.TCPMinimumSize
- hdr, ok := pkt.Data.PullUp(hdrLen)
+ hdr, ok := pkt.Data().PullUp(hdrLen)
if !ok {
return false
}
// If the header has options, pull those up as well.
- if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
+ if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data().Size() {
// TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of
// packets.
hdrLen = offset
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 4c6f808e5..adc835d30 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -45,9 +45,23 @@ const (
TCPMaxSACKBlocks = 4
)
+// TCPFlags is the dedicated type for TCP flags.
+type TCPFlags uint8
+
+// String implements Stringer.String.
+func (f TCPFlags) String() string {
+ flagsStr := []byte("FSRPAU")
+ for i := range flagsStr {
+ if f&(1<<uint(i)) == 0 {
+ flagsStr[i] = ' '
+ }
+ }
+ return string(flagsStr)
+}
+
// Flags that may be set in a TCP segment.
const (
- TCPFlagFin = 1 << iota
+ TCPFlagFin TCPFlags = 1 << iota
TCPFlagSyn
TCPFlagRst
TCPFlagPsh
@@ -94,7 +108,7 @@ type TCPFields struct {
DataOffset uint8
// Flags is the "flags" field of a TCP packet.
- Flags uint8
+ Flags TCPFlags
// WindowSize is the "window size" field of a TCP packet.
WindowSize uint16
@@ -234,8 +248,8 @@ func (b TCP) Payload() []byte {
}
// Flags returns the flags field of the tcp header.
-func (b TCP) Flags() uint8 {
- return b[TCPFlagsOffset]
+func (b TCP) Flags() TCPFlags {
+ return TCPFlags(b[TCPFlagsOffset])
}
// WindowSize returns the "window size" field of the tcp header.
@@ -319,10 +333,10 @@ func (b TCP) ParsedOptions() TCPOptions {
return ParseTCPOptions(b.Options())
}
-func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) {
+func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) {
binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq)
binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack)
- b[TCPFlagsOffset] = flags
+ b[TCPFlagsOffset] = uint8(flags)
binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
}
@@ -338,7 +352,7 @@ func (b TCP) Encode(t *TCPFields) {
// EncodePartial updates a subset of the fields of the tcp header. It is useful
// in cases when similar segments are produced.
-func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) {
+func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) {
// Add the total length and "flags" field contributions to the checksum.
// We don't use the flags field directly from the header because it's a
// one-byte field with an odd offset, so it would be accounted for
diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go
index 72563837b..96db8460f 100644
--- a/pkg/tcpip/header/tcp_test.go
+++ b/pkg/tcpip/header/tcp_test.go
@@ -146,3 +146,23 @@ func TestTCPParseOptions(t *testing.T) {
}
}
}
+
+func TestTCPFlags(t *testing.T) {
+ for _, tt := range []struct {
+ flags header.TCPFlags
+ want string
+ }{
+ {header.TCPFlagFin, "F "},
+ {header.TCPFlagSyn, " S "},
+ {header.TCPFlagRst, " R "},
+ {header.TCPFlagPsh, " P "},
+ {header.TCPFlagAck, " A "},
+ {header.TCPFlagUrg, " U"},
+ {header.TCPFlagSyn | header.TCPFlagAck, " S A "},
+ {header.TCPFlagFin | header.TCPFlagAck, "F A "},
+ } {
+ if got := tt.flags.String(); got != tt.want {
+ t.Errorf("got TCPFlags(%#b).String() = %s, want = %s", tt.flags, got, tt.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 72d3f70ac..e17e2085c 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -427,7 +427,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
vnetHdr.csumOffset = gso.CsumOffset
}
- if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS {
+ if gso.Type != stack.GSONone && uint16(pkt.Data().Size()) > gso.MSS {
switch gso.Type {
case stack.GSOTCPv4:
vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
@@ -468,7 +468,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp
vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
}
- if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data.Size()) > pkt.GSOOptions.MSS {
+ if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS {
switch pkt.GSOOptions.Type {
case stack.GSOTCPv4:
vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 358a030d2..1e40f3fef 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -67,7 +67,7 @@ func checkPacketInfoEqual(t *testing.T, got, want packetInfo) {
LinkHeader: pk.LinkHeader().View(),
NetworkHeader: pk.NetworkHeader().View(),
TransportHeader: pk.TransportHeader().View(),
- Data: pk.Data.ToView(),
+ Data: pk.Data().AsRange().ToOwnedView(),
}
}),
); diff != "" {
@@ -616,8 +616,8 @@ func TestDispatchPacketFormat(t *testing.T) {
if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want {
t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want)
}
- if got, want := pkt.Data.Size(), 4; got != want {
- t.Errorf("pkt.Data.Size() = %d, want %d", got, want)
+ if got, want := pkt.Data().Size(), 4; got != want {
+ t.Errorf("pkt.Data().Size() = %d, want %d", got, want)
}
})
}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 736871d1c..46df87f44 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -165,7 +165,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
// We don't get any indication of what the packet is, so try to guess
// if it's an IPv4 or IPv6 packet.
// IP version information is at the first octet, so pulling up 1 byte.
- h, ok := pkt.Data.PullUp(1)
+ h, ok := pkt.Data().PullUp(1)
if !ok {
return true, nil
}
@@ -270,7 +270,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
// We don't get any indication of what the packet is, so try to guess
// if it's an IPv4 or IPv6 packet.
// IP version information is at the first octet, so pulling up 1 byte.
- h, ok := pkt.Data.PullUp(1)
+ h, ok := pkt.Data().PullUp(1)
if !ok {
// Skip this packet.
continue
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index def47772f..d4b3ddd5c 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -80,7 +80,7 @@ func (q *queueBuffers) cleanup() {
type packetInfo struct {
addr tcpip.LinkAddress
proto tcpip.NetworkProtocolNumber
- vv buffer.VectorisedView
+ data buffer.View
linkHeader buffer.View
}
@@ -136,7 +136,7 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L
c.packets = append(c.packets, packetInfo{
addr: remoteLinkAddr,
proto: proto,
- vv: pkt.Data.Clone(nil),
+ data: pkt.Data().AsRange().ToOwnedView(),
})
c.mu.Unlock()
@@ -676,7 +676,7 @@ func TestSimpleReceive(t *testing.T) {
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
c.mu.Lock()
- rcvd := []byte(c.packets[0].vv.ToView())
+ rcvd := []byte(c.packets[0].data)
c.packets = c.packets[:0]
c.mu.Unlock()
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index bd2b8d4bf..7aaee3d13 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -290,7 +290,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
- hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
+ hdr, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
if !ok {
break
}
@@ -327,7 +327,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe
case header.ICMPv6ProtocolNumber:
transName = "icmp"
- hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
+ hdr, ok := pkt.Data().PullUp(header.ICMPv6MinimumSize)
if !ok {
break
}
@@ -387,7 +387,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe
details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
break
}
- if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments {
+ if size := pkt.Data().Size() + len(tcp); offset > size && !moreFragments {
details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size)
break
}
@@ -398,13 +398,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe
// Initialize the TCP flags.
flags := tcp.Flags()
- flagsStr := []byte("FSRPAU")
- for i := range flagsStr {
- if flags&(1<<uint(i)) == 0 {
- flagsStr[i] = ' '
- }
- }
- details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
+ details = fmt.Sprintf("flags: %s seqnum: %d ack: %d win: %d xsum:0x%x", flags, tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
if flags&header.TCPFlagSyn != 0 {
details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
} else {
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 3829ca9c9..c1678c4f4 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -281,7 +281,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
vv.AppendView(info.Pkt.NetworkHeader().View())
vv.AppendView(info.Pkt.TransportHeader().View())
// Append data payload.
- vv.Append(info.Pkt.Data)
+ vv.Append(info.Pkt.Data().ExtractVV())
return vv.ToView(), true
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 3fcdea119..ae0461a6d 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -232,7 +232,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
e.mu.Lock()
- e.mu.dad.StopLocked(addr, false /* aborted */)
+ e.mu.dad.StopLocked(addr, &stack.DADDupAddrDetected{HolderLinkAddress: linkAddr})
e.mu.Unlock()
// The solicited, override, and isRouter flags are not available for ARP;
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go
index 243738951..5168f5361 100644
--- a/pkg/tcpip/network/internal/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go
@@ -170,7 +170,7 @@ func (f *Fragmentation) Process(
return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
}
- if l := pkt.Data.Size(); l != int(fragmentSize) {
+ if l := pkt.Data().Size(); l != int(fragmentSize) {
return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
}
@@ -293,7 +293,7 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, re
// these headers.
var fragmentableData buffer.VectorisedView
fragmentableData.AppendView(pkt.TransportHeader().View())
- fragmentableData.Append(pkt.Data)
+ fragmentableData.Append(pkt.Data().ExtractVV())
fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen
return PacketFragmenter{
@@ -323,7 +323,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int,
})
// Copy data for the fragment.
- copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen)
+ copied := fragPkt.Data().ReadFromVV(&pf.data, pf.fragmentPayloadLen)
offset := pf.fragmentOffset
pf.fragmentOffset += copied
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
index 47ea3173e..7daf64b4a 100644
--- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
@@ -121,7 +121,7 @@ func TestFragmentationProcess(t *testing.T) {
in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done)
}
if c.out[i].done {
- if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" {
+ if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data().AsRange().ToOwnedView()); diff != "" {
t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s",
in.id, in.first, in.last, in.more, in.proto, in.pkt, diff)
}
@@ -470,9 +470,7 @@ func TestPacketFragmenter(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
- var originalPayload buffer.VectorisedView
- originalPayload.AppendView(pkt.TransportHeader().View())
- originalPayload.Append(pkt.Data)
+ originalPayload := stack.PayloadSince(pkt.TransportHeader())
var reassembledPayload buffer.VectorisedView
pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve)
for i := 0; ; i++ {
@@ -499,7 +497,7 @@ func TestPacketFragmenter(t *testing.T) {
if got := fragPkt.TransportHeader().View().Size(); got != 0 {
t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
}
- reassembledPayload.Append(fragPkt.Data)
+ reassembledPayload.AppendViews(fragPkt.Data().Views())
if !more {
if i != len(test.wantFragments)-1 {
t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
@@ -507,7 +505,7 @@ func TestPacketFragmenter(t *testing.T) {
break
}
}
- if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" {
+ if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload); diff != "" {
t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
}
})
@@ -625,11 +623,11 @@ func TestTimeoutHandler(t *testing.T) {
}
switch {
case handler.pkt != nil && test.wantPkt == nil:
- t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView())
+ t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data().AsRange().ToOwnedView())
case handler.pkt == nil && test.wantPkt != nil:
- t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView())
+ t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data().AsRange().ToOwnedView())
case handler.pkt != nil && test.wantPkt != nil:
- if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" {
+ if diff := cmp.Diff(test.wantPkt.Data().AsRange().ToOwnedView(), handler.pkt.Data().AsRange().ToOwnedView()); diff != "" {
t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff)
}
}
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go
index 933d63d32..90075a70c 100644
--- a/pkg/tcpip/network/internal/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go
@@ -167,8 +167,8 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
resPkt := r.holes[0].pkt
for i := 1; i < len(r.holes); i++ {
- fragPkt := r.holes[i].pkt
- fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size())
+ fragData := r.holes[i].pkt.Data()
+ resPkt.Data().ReadFromData(fragData, fragData.Size())
}
return resPkt, r.proto, true, memConsumed, nil
}
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
index 214a93709..cfd9f00ef 100644
--- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
@@ -204,7 +204,7 @@ func TestReassemblerProcess(t *testing.T) {
if a == nil || b == nil {
return a == b
}
- return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView())
+ return bytes.Equal(a.Data().AsRange().ToOwnedView(), b.Data().AsRange().ToOwnedView())
}
if isDone {
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
index 6f89a6a16..0053646ee 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
@@ -126,9 +126,12 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet
s.timer.Stop()
delete(d.addresses, addr)
- r := stack.DADResult{Resolved: dadDone, Err: err}
+ var res stack.DADResult = &stack.DADSucceeded{}
+ if err != nil {
+ res = &stack.DADError{Err: err}
+ }
for _, h := range s.completionHandlers {
- h(r)
+ h(res)
}
}),
}
@@ -142,7 +145,7 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet
// StopLocked stops a currently running DAD process.
//
// Precondition: d.protocolMU must be locked.
-func (d *DAD) StopLocked(addr tcpip.Address, aborted bool) {
+func (d *DAD) StopLocked(addr tcpip.Address, reason stack.DADResult) {
s, ok := d.addresses[addr]
if !ok {
return
@@ -152,14 +155,8 @@ func (d *DAD) StopLocked(addr tcpip.Address, aborted bool) {
s.timer.Stop()
delete(d.addresses, addr)
- var err tcpip.Error
- if aborted {
- err = &tcpip.ErrAborted{}
- }
-
- r := stack.DADResult{Resolved: false, Err: err}
for _, h := range s.completionHandlers {
- h(r)
+ h(reason)
}
}
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
index 18c357b56..e00aa4678 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
@@ -78,10 +78,10 @@ func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADC
return m.mu.dad.CheckDuplicateAddressLocked(addr, h)
}
-func (m *mockDADProtocol) stop(addr tcpip.Address, aborted bool) {
+func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) {
m.mu.Lock()
defer m.mu.Unlock()
- m.mu.dad.StopLocked(addr, aborted)
+ m.mu.dad.StopLocked(addr, reason)
}
func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) {
@@ -175,7 +175,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
}
clock.Advance(delta)
for i := 0; i < 2; i++ {
- if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff)
}
}
@@ -189,7 +189,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
default:
}
clock.Advance(delta)
- if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
@@ -202,7 +202,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
clock.Advance(dadConfig2Duration)
- if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
@@ -241,19 +241,19 @@ func TestDADStop(t *testing.T) {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
- dad.stop(addr1, true /* aborted */)
- if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: false, Err: &tcpip.ErrAborted{}}}, <-ch); diff != "" {
+ dad.stop(addr1, &stack.DADAborted{})
+ if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADAborted{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
- dad.stop(addr2, false /* aborted */)
- if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: false, Err: nil}}, <-ch); diff != "" {
+ dad.stop(addr2, &stack.DADDupAddrDetected{})
+ if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADDupAddrDetected{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer
clock.Advance(dadResolutionDuration)
- if diff := cmp.Diff(dadResult{Addr: addr3, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ if diff := cmp.Diff(dadResult{Addr: addr3, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
@@ -266,7 +266,7 @@ func TestDADStop(t *testing.T) {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
clock.Advance(dadResolutionDuration)
- if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index 5f7e60c5c..b6f39ddb1 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -69,8 +69,8 @@ type MultiCounterIPStats struct {
IPTablesOutputDropped tcpip.MultiCounterStat
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
-
// of IPStats.
+
// OptionTimestampReceived is the number of Timestamp options seen.
OptionTimestampReceived tcpip.MultiCounterStat
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 90236ed9e..aee1652fa 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -90,8 +90,7 @@ type testObject struct {
// checkValues verifies that the transport protocol, data contents, src & dst
// addresses of a packet match what's expected. If any field doesn't match, the
// test fails.
-func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) {
- v := vv.ToView()
+func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v buffer.View, srcAddr, dstAddr tcpip.Address) {
if protocol != t.protocol {
t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
}
@@ -120,7 +119,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// parsing are expected.
func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
netHdr := pkt.Network()
- t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress())
+ t.checkValues(protocol, pkt.Data().AsRange().ToOwnedView(), netHdr.SourceAddress(), netHdr.DestinationAddress())
t.dataCalls++
return stack.TransportPacketHandled
}
@@ -129,7 +128,7 @@ func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumb
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) {
- t.checkValues(trans, pkt.Data, remote, local)
+ t.checkValues(trans, pkt.Data().AsRange().ToOwnedView(), remote, local)
if diff := cmp.Diff(
t.transErr,
transportError{
@@ -198,7 +197,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
}
- t.checkValues(prot, pkt.Data, srcAddr, dstAddr)
+ t.checkValues(prot, pkt.Data().AsRange().ToOwnedView(), srcAddr, dstAddr)
return nil
}
@@ -371,7 +370,11 @@ func TestSourceAddressValidation(t *testing.T) {
pkt.SetType(header.ICMPv6EchoRequest)
pkt.SetCode(0)
pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: src,
+ Dst: localIPv6Addr,
+ }))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: header.ICMPv6MinimumSize,
@@ -1199,7 +1202,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
nic.testObject.transErr = c.transErr
// Set ICMPv6 checksum.
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp,
+ Src: outerSrcAddr,
+ Dst: localIPv6Addr,
+ }))
addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
if !ok {
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 4b21ee79c..5e7f10f4b 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -32,12 +32,14 @@ go_test(
"ipv4_test.go",
],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/internal/testutil",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index bd0eabad1..deb104837 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -137,7 +137,7 @@ func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool {
// is used to find out which transport endpoint must be notified about the ICMP
// packet. We only expect the payload, not the enclosing ICMP packet.
func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.PacketBuffer) {
- h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
if !ok {
return
}
@@ -156,7 +156,7 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
}
hlen := int(hdr.HeaderLength())
- if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 {
+ if pkt.Data().Size() < hlen || hdr.FragmentOffset() != 0 {
// We won't be able to handle this if it doesn't contain the
// full IPv4 header, or if it's a fragment not at offset 0
// (because it won't have the transport header).
@@ -164,7 +164,7 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
}
// Skip the ip header, then deliver the error.
- pkt.Data.TrimFront(hlen)
+ pkt.Data().TrimFront(hlen)
p := hdr.TransportProtocol()
e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt)
}
@@ -174,7 +174,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
- v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
+ v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
if !ok {
received.invalid.Increment()
return
@@ -182,7 +182,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
h := header.ICMPv4(v)
// Only do in-stack processing if the checksum is correct.
- if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff {
+ if pkt.Data().AsRange().Checksum() != 0xffff {
received.invalid.Increment()
// It's possible that a raw socket expects to receive this regardless
// of checksum errors. If it's an echo request we know it's safe because
@@ -238,7 +238,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv4Echo:
- received.echo.Increment()
+ received.echoRequest.Increment()
sent := e.stats.icmp.packetsSent
if !e.protocol.stack.AllowICMPMessage() {
@@ -253,7 +253,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
// waiting endpoints. Consider moving responsibility for doing the copy to
// DeliverTransportPacket so that is is only done when needed.
- replyData := pkt.Data.ToOwnedView()
+ replyData := pkt.Data().AsRange().ToOwnedView()
ipHdr := header.IPv4(pkt.NetworkHeader().View())
localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast
@@ -336,7 +336,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4DstUnreachable:
received.dstUnreachable.Increment()
- pkt.Data.TrimFront(header.ICMPv4MinimumSize)
+ pkt.Data().TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
case header.ICMPv4HostUnreachable:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
@@ -571,7 +571,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return nil
}
- payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size()
+ payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data().Size()
if payloadLen > available {
payloadLen = available
}
@@ -586,8 +586,11 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
newHeader := append(buffer.View(nil), origIPHdr...)
newHeader = append(newHeader, transportHeader...)
payload := newHeader.ToVectorisedView()
- payload.AppendView(pkt.Data.ToView())
- payload.CapLength(payloadLen)
+ if dataCap := payloadLen - payload.Size(); dataCap > 0 {
+ payload.AppendView(pkt.Data().AsRange().Capped(dataCap).ToOwnedView())
+ } else {
+ payload.CapLength(payloadLen)
+ }
icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize,
@@ -623,7 +626,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
- icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+ icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum()))
if err := route.WritePacket(
nil, /* gso */
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index 0a15ae897..f3fc1c87e 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -197,7 +197,7 @@ func (igmp *igmpState) isPacketValidLocked(pkt *stack.PacketBuffer, messageType
// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer, hasRouterAlertOption bool) {
received := igmp.ep.stats.igmp.packetsReceived
- headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize)
+ headerView, ok := pkt.Data().PullUp(header.IGMPMinimumSize)
if !ok {
received.invalid.Increment()
return
@@ -210,7 +210,7 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer, hasRouterAlertOption
// same set of octets, including the checksum field. If the result
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
- if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xFFFF {
+ if pkt.Data().AsRange().Checksum() != 0xFFFF {
received.checksumErrors.Increment()
return
}
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index c5f68e411..e5e1b89cc 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -106,9 +106,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
igmp.SetGroupAddress(groupAddress)
igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
- e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
}
// TestIGMPV1Present tests the node's ability to fallback to V1 when a V1
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 4a429ea6c..8a2140ebe 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -492,7 +492,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
- h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
if !ok {
return &tcpip.ErrMalformedHeader{}
}
@@ -502,14 +502,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return &tcpip.ErrMalformedHeader{}
}
- h, ok = pkt.Data.PullUp(int(hdrLen))
+ h, ok = pkt.Data().PullUp(int(hdrLen))
if !ok {
return &tcpip.ErrMalformedHeader{}
}
ip := header.IPv4(h)
// Always set the total length.
- pktSize := pkt.Data.Size()
+ pktSize := pkt.Data().Size()
ip.SetTotalLength(uint16(pktSize))
// Set the source address when zero.
@@ -687,7 +687,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
stats := e.stats
h := header.IPv4(pkt.NetworkHeader().View())
- if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
+ if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
stats.ip.MalformedPacketsReceived.Increment()
return
}
@@ -765,7 +765,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
}
if h.More() || h.FragmentOffset() != 0 {
- if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
+ if pkt.Data().Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
stats.ip.MalformedPacketsReceived.Increment()
@@ -793,10 +793,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// maximum payload size.
//
// Note that this addition doesn't overflow even on 32bit architecture
- // because pkt.Data.Size() should not exceed 65535 (the max IP datagram
+ // because pkt.Data().Size() should not exceed 65535 (the max IP datagram
// size). Otherwise the packet would've been rejected as invalid before
// reaching here.
- if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
+ if int(start)+pkt.Data().Size() > header.IPv4MaximumPayloadSize {
stats.ip.MalformedPacketsReceived.Increment()
stats.ip.MalformedFragmentsReceived.Increment()
return
@@ -813,7 +813,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
Protocol: proto,
},
start,
- start+uint16(pkt.Data.Size())-1,
+ start+uint16(pkt.Data().Size())-1,
h.More(),
proto,
pkt,
@@ -831,7 +831,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// The reassembler doesn't take care of fixing up the header, so we need
// to do it here.
- h.SetTotalLength(uint16(pkt.Data.Size() + len((h))))
+ h.SetTotalLength(uint16(pkt.Data().Size() + len((h))))
h.SetFlagsFragmentOffset(0, 0)
}
stats.ip.PacketsDelivered.Increment()
@@ -899,10 +899,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {
e.mu.Lock()
- defer e.mu.Unlock()
-
e.disableLocked()
e.mu.addressableEndpointState.Cleanup()
+ e.mu.Unlock()
e.protocol.forgetEndpoint(e.nic.ID())
}
@@ -1186,7 +1185,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error
}
func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
- payload := pkt.TransportHeader().View().Size() + pkt.Data.Size()
+ payload := pkt.TransportHeader().View().Size() + pkt.Data().Size()
return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index dc4db6e5f..cfed241bf 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -26,12 +26,14 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
@@ -1211,7 +1213,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset)
}
reassembledPayload.AppendView(packet.TransportHeader().View())
- reassembledPayload.Append(packet.Data)
+ reassembledPayload.AppendView(packet.Data().AsRange().ToOwnedView())
// Clear out the checksum and length from the ip because we can't compare
// it.
sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize)
@@ -2985,3 +2987,120 @@ func TestPacketQueing(t *testing.T) {
})
}
}
+
+// TestCloseLocking test that lock ordering is followed when closing an
+// endpoint.
+func TestCloseLocking(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ src = tcpip.Address("\x10\x00\x00\x01")
+ dst = tcpip.Address("\x10\x00\x00\x02")
+
+ iterations = 1000
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ // Perform NAT so that the endoint tries to search for a sibling endpoint
+ // which ends up taking the protocol and endpoint lock (in that order).
+ table := stack.Table{
+ Rules: []stack.Rule{
+ {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 1,
+ stack.Forward: stack.HookUnset,
+ stack.Output: 2,
+ stack.Postrouting: 3,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 1,
+ stack.Forward: stack.HookUnset,
+ stack.Output: 2,
+ stack.Postrouting: 3,
+ },
+ }
+ if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil {
+ t.Fatalf("s.IPTables().ReplaceTable(...): %s", err)
+ }
+
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ }})
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
+
+ addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53}
+ if err := ep.Connect(addr); err != nil {
+ t.Errorf("ep.Connect(%#v): %s", addr, err)
+ }
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ // Writing packets should trigger NAT which requires the stack to search the
+ // protocol for network endpoints with the destination address.
+ //
+ // Creating and removing interfaces should modify the protocol and endpoint
+ // which requires taking the locks of each.
+ //
+ // We expect the protocol > endpoint lock ordering to be followed here.
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+
+ data := []byte{1, 2, 3, 4}
+
+ for i := 0; i < iterations; i++ {
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Errorf("ep.Write(_, _): %s", err)
+ return
+ } else if want := int64(len(data)); n != want {
+ t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
+ return
+ }
+ }
+ }()
+ go func() {
+ defer wg.Done()
+
+ for i := 0; i < iterations; i++ {
+ if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
+ t.Errorf("CreateNIC(%d, _): %s", nicID2, err)
+ return
+ }
+ if err := s.RemoveNIC(nicID2); err != nil {
+ t.Errorf("RemoveNIC(%d): %s", nicID2, err)
+ return
+ }
+ }
+ }()
+}
diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go
index 5ae73fbfb..5798cfec6 100644
--- a/pkg/tcpip/network/ipv4/stats.go
+++ b/pkg/tcpip/network/ipv4/stats.go
@@ -52,7 +52,7 @@ type sharedStats struct {
// LINT.IfChange(multiCounterICMPv4PacketStats)
type multiCounterICMPv4PacketStats struct {
- echo tcpip.MultiCounterStat
+ echoRequest tcpip.MultiCounterStat
echoReply tcpip.MultiCounterStat
dstUnreachable tcpip.MultiCounterStat
srcQuench tcpip.MultiCounterStat
@@ -66,7 +66,7 @@ type multiCounterICMPv4PacketStats struct {
}
func (m *multiCounterICMPv4PacketStats) init(a, b *tcpip.ICMPv4PacketStats) {
- m.echo.Init(a.Echo, b.Echo)
+ m.echoRequest.Init(a.EchoRequest, b.EchoRequest)
m.echoReply.Init(a.EchoReply, b.EchoReply)
m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable)
m.srcQuench.Init(a.SrcQuench, b.SrcQuench)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 5f44ab317..6344a3e09 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -18,7 +18,6 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -165,7 +164,7 @@ func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool {
// used to find out which transport endpoint must be notified about the ICMP
// packet.
func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.PacketBuffer) {
- h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ h, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
if !ok {
return
}
@@ -184,10 +183,10 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip the IP header, then handle the fragmentation header if there
// is one.
- pkt.Data.TrimFront(header.IPv6MinimumSize)
+ pkt.Data().TrimFront(header.IPv6MinimumSize)
p := hdr.TransportProtocol()
if p == header.IPv6FragmentHeader {
- f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize)
+ f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize)
if !ok {
return
}
@@ -200,7 +199,7 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip fragmentation header and find out the actual protocol
// number.
- pkt.Data.TrimFront(header.IPv6FragmentHeaderSize)
+ pkt.Data().TrimFront(header.IPv6FragmentHeaderSize)
p = fragHdr.TransportProtocol()
}
@@ -268,7 +267,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
if routerAlert == nil || routerAlert.Value != header.IPv6RouterAlertMLD {
return false
}
- if pkt.Data.Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize {
+ if pkt.Data().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize {
return false
}
if iph.HopLimit() != header.MLDHopLimit {
@@ -285,7 +284,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
received := e.stats.icmp.packetsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
// fields set. See icmp/protocol.go:protocol.Parse for a full explanation.
- v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
+ v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize)
if !ok {
received.invalid.Increment()
return
@@ -296,11 +295,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
dstAddr := iph.DestinationAddress()
// Validate ICMPv6 checksum before processing the packet.
- //
- // This copy is used as extra payload during the checksum calculation.
- payload := pkt.Data.Clone(nil)
- payload.TrimFront(len(h))
- if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want {
+ payload := pkt.Data().AsRange().SubRange(len(h))
+ if got, want := h.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: h,
+ Src: srcAddr,
+ Dst: dstAddr,
+ PayloadCsum: payload.Checksum(),
+ PayloadLen: payload.Size(),
+ }); got != want {
received.invalid.Increment()
return
}
@@ -320,12 +322,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.packetTooBig.Increment()
- hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
+ hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize)
if !ok {
received.invalid.Increment()
return
}
- pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
+ pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize)
networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
if err != nil {
networkMTU = 0
@@ -334,12 +336,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
case header.ICMPv6DstUnreachable:
received.dstUnreachable.Increment()
- hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize)
+ hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize)
if !ok {
received.invalid.Increment()
return
}
- pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
+ pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch header.ICMPv6(hdr).Code() {
case header.ICMPv6NetworkUnreachable:
e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
@@ -348,16 +350,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
}
case header.ICMPv6NeighborSolicit:
received.neighborSolicit.Increment()
- if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize {
+ if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborSolicitMinimumSize {
received.invalid.Increment()
return
}
// The remainder of payload must be only the neighbor solicitation, so
- // payload.ToView() always returns the solicitation. Per RFC 6980 section 5,
+ // payload.AsView() always returns the solicitation. Per RFC 6980 section 5,
// NDP messages cannot be fragmented. Also note that in the common case NDP
- // datagrams are very small and ToView() will not incur allocations.
- ns := header.NDPNeighborSolicit(payload.ToView())
+ // datagrams are very small and AsView() will not incur allocations.
+ ns := header.NDPNeighborSolicit(payload.AsView())
targetAddr := ns.TargetAddress()
// As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast
@@ -380,6 +382,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
// stack know so it can handle such a scenario and do nothing further with
// the NS.
if srcAddr == header.IPv6Any {
+ // Since this is a DAD message we know the sender does not actually hold
+ // the target address so there is no "holder".
+ var holderLinkAddress tcpip.LinkAddress
+
// We would get an error if the address no longer exists or the address
// is no longer tentative (DAD resolved between the call to
// hasTentativeAddr and this point). Both of these are valid scenarios:
@@ -391,7 +397,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) {
+ switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress); err.(type) {
case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
default:
panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
@@ -529,7 +535,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
na.SetOverrideFlag(true)
na.SetTargetAddress(targetAddr)
na.Options().Serialize(optsSerializer)
- packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ packet.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: packet,
+ Src: r.LocalAddress,
+ Dst: r.RemoteAddress,
+ }))
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
//
@@ -545,20 +555,34 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
case header.ICMPv6NeighborAdvert:
received.neighborAdvert.Increment()
- if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize {
+ if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborAdvertMinimumSize {
received.invalid.Increment()
return
}
// The remainder of payload must be only the neighbor advertisement, so
- // payload.ToView() always returns the advertisement. Per RFC 6980 section
+ // payload.AsView() always returns the advertisement. Per RFC 6980 section
// 5, NDP messages cannot be fragmented. Also note that in the common case
- // NDP datagrams are very small and ToView() will not incur allocations.
- na := header.NDPNeighborAdvert(payload.ToView())
+ // NDP datagrams are very small and AsView() will not incur allocations.
+ na := header.NDPNeighborAdvert(payload.AsView())
+
+ it, err := na.Options().Iter(false /* check */)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.invalid.Increment()
+ return
+ }
+
+ targetLinkAddr, ok := getTargetLinkAddr(it)
+ if !ok {
+ received.invalid.Increment()
+ return
+ }
+
targetAddr := na.TargetAddress()
e.dad.mu.Lock()
- e.dad.mu.dad.StopLocked(targetAddr, false /* aborted */)
+ e.dad.mu.dad.StopLocked(targetAddr, &stack.DADDupAddrDetected{HolderLinkAddress: targetLinkAddr})
e.dad.mu.Unlock()
if e.hasTentativeAddr(targetAddr) {
@@ -578,7 +602,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) {
+ switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr); err.(type) {
case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
return
default:
@@ -586,13 +610,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
}
}
- it, err := na.Options().Iter(false /* check */)
- if err != nil {
- // If we have a malformed NDP NA option, drop the packet.
- received.invalid.Increment()
- return
- }
-
// At this point we know that the target address is not tentative on the
// NIC. However, the target address may still be assigned to the NIC but not
// tentative (it could be permanent). Such a scenario is beyond the scope of
@@ -602,11 +619,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
// TODO(b/143147598): Handle the scenario described above. Also inform the
// netstack integration that a duplicate address was detected outside of
// DAD.
- targetLinkAddr, ok := getTargetLinkAddr(it)
- if !ok {
- received.invalid.Increment()
- return
- }
// As per RFC 4861 section 7.1.2:
// A node MUST silently discard any received Neighbor Advertisement
@@ -657,13 +669,20 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
- Data: pkt.Data,
+ Data: pkt.Data().ExtractVV(),
})
- packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize))
+ icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
- copy(packet, icmpHdr)
- packet.SetType(header.ICMPv6EchoReply)
- packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
+ copy(icmp, icmpHdr)
+ icmp.SetType(header.ICMPv6EchoReply)
+ dataRange := replyPkt.Data().AsRange()
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp,
+ Src: r.LocalAddress,
+ Dst: r.RemoteAddress,
+ PayloadCsum: dataRange.Checksum(),
+ PayloadLen: dataRange.Size(),
+ }))
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: r.DefaultTTL(),
@@ -676,7 +695,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
case header.ICMPv6EchoReply:
received.echoReply.Increment()
- if pkt.Data.Size() < header.ICMPv6EchoMinimumSize {
+ if pkt.Data().Size() < header.ICMPv6EchoMinimumSize {
received.invalid.Increment()
return
}
@@ -696,7 +715,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// Is the NDP payload of sufficient size to hold a Router Solictation?
- if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
+ if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
received.invalid.Increment()
return
}
@@ -710,9 +729,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- // Note that in the common case NDP datagrams are very small and ToView()
+ // Note that in the common case NDP datagrams are very small and AsView()
// will not incur allocations.
- rs := header.NDPRouterSolicit(payload.ToView())
+ rs := header.NDPRouterSolicit(payload.AsView())
it, err := rs.Options().Iter(false /* check */)
if err != nil {
// Options are not valid as per the wire format, silently drop the packet.
@@ -756,7 +775,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// Is the NDP payload of sufficient size to hold a Router Advertisement?
- if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
+ if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
received.invalid.Increment()
return
}
@@ -770,9 +789,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- // Note that in the common case NDP datagrams are very small and ToView()
+ // Note that in the common case NDP datagrams are very small and AsView()
// will not incur allocations.
- ra := header.NDPRouterAdvert(payload.ToView())
+ ra := header.NDPRouterAdvert(payload.AsView())
it, err := ra.Options().Iter(false /* check */)
if err != nil {
// Options are not valid as per the wire format, silently drop the packet.
@@ -850,11 +869,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
switch icmpType {
case header.ICMPv6MulticastListenerQuery:
e.mu.Lock()
- e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView()))
+ e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.AsView()))
e.mu.Unlock()
case header.ICMPv6MulticastListenerReport:
e.mu.Lock()
- e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView()))
+ e.mu.mld.handleMulticastListenerReport(header.MLD(payload.AsView()))
e.mu.Unlock()
case header.ICMPv6MulticastListenerDone:
default:
@@ -1077,13 +1096,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
if available < header.IPv6MinimumSize {
return nil
}
- payloadLen := network.Size() + transport.Size() + pkt.Data.Size()
+ payloadLen := network.Size() + transport.Size() + pkt.Data().Size()
if payloadLen > available {
payloadLen = available
}
payload := network.ToVectorisedView()
payload.AppendView(transport)
- payload.Append(pkt.Data)
+ payload.Append(pkt.Data().ExtractVV())
payload.CapLength(payloadLen)
newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -1115,7 +1134,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data))
+ dataRange := newPkt.Data().AsRange()
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpHdr,
+ Src: route.LocalAddress,
+ Dst: route.RemoteAddress,
+ PayloadCsum: dataRange.Checksum(),
+ PayloadLen: dataRange.Size(),
+ }))
if err := route.WritePacket(
nil, /* gso */
stack.NetworkHeaderParams{
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index c27164344..d4e63710c 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -324,7 +324,13 @@ func TestICMPCounts(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp[:typ.size],
+ Src: lladdr0,
+ Dst: lladdr1,
+ PayloadCsum: header.Checksum(typ.extraData, 0 /* initial */),
+ PayloadLen: len(typ.extraData),
+ }))
handleICMPInIPv6(ep, lladdr1, lladdr0, icmp, typ.hopLimit, typ.includeRouterAlert)
}
@@ -498,7 +504,11 @@ func TestLinkResolution(t *testing.T) {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: r.LocalAddress,
+ Dst: r.RemoteAddress,
+ }))
// We can't send our payload directly over the route because that
// doesn't provoke NDP discovery.
@@ -687,7 +697,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
if checksum {
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp,
+ Src: lladdr1,
+ Dst: lladdr0,
+ }))
}
ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -879,7 +893,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
payloadFn(icmpHdr.Payload())
if checksum {
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{}))
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpHdr,
+ Src: lladdr1,
+ Dst: lladdr0,
+ }))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -1058,7 +1076,13 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
payloadFn(payload)
if checksum {
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView()))
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpHdr,
+ Src: lladdr1,
+ Dst: lladdr0,
+ PayloadCsum: header.Checksum(payload, 0 /* initial */),
+ PayloadLen: len(payload),
+ }))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -1324,7 +1348,11 @@ func TestPacketQueing(t *testing.T) {
pkt.SetType(header.ICMPv6EchoRequest)
pkt.SetCode(0)
pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: host2IPv6Addr.AddressWithPrefix.Address,
+ Dst: host1IPv6Addr.AddressWithPrefix.Address,
+ }))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: header.ICMPv6MinimumSize,
@@ -1422,7 +1450,11 @@ func TestPacketQueing(t *testing.T) {
na.Options().Serialize(header.NDPOptionsSerializer{
header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr),
})
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: host2IPv6Addr.AddressWithPrefix.Address,
+ Dst: host1IPv6Addr.AddressWithPrefix.Address,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -1661,7 +1693,11 @@ func TestCallsToNeighborCache(t *testing.T) {
}
icmp := test.createPacket()
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp,
+ Src: test.source,
+ Dst: test.destination,
+ }))
handleICMPInIPv6(ep, test.source, test.destination, icmp, header.NDPHopLimit, false)
// Confirm the endpoint calls the correct NUDHandler method.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 7638ade35..46b6cc41a 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -348,7 +348,7 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
// dupTentativeAddrDetected removes the tentative address if it exists. If the
// address was generated via SLAAC, an attempt is made to generate a new
// address.
-func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error {
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -363,7 +363,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error {
// If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
// attempt will be made to generate a new address for it.
- if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, true /* dadFailure */); err != nil {
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil {
return err
}
@@ -536,8 +536,20 @@ func (e *endpoint) disableLocked() {
}
e.mu.ndp.stopSolicitingRouters()
+ // Stop DAD for all the tentative unicast addresses.
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return true
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if header.IsV6UnicastAddress(addr) {
+ e.mu.ndp.stopDuplicateAddressDetection(addr, &stack.DADAborted{})
+ }
+
+ return true
+ })
e.mu.ndp.cleanupState(false /* hostOnly */)
- e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) {
@@ -555,25 +567,6 @@ func (e *endpoint) disableLocked() {
}
}
-// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
-//
-// Precondition: e.mu must be write locked.
-func (e *endpoint) stopDADForPermanentAddressesLocked() {
- // Stop DAD for all the tentative unicast addresses.
- e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
- if addressEndpoint.GetKind() != stack.PermanentTentative {
- return true
- }
-
- addr := addressEndpoint.AddressWithPrefix().Address
- if header.IsV6UnicastAddress(addr) {
- e.mu.ndp.stopDuplicateAddressDetection(addr, false /* failed */)
- }
-
- return true
- })
-}
-
// DefaultTTL is the default hop limit for this endpoint.
func (e *endpoint) DefaultTTL() uint8 {
return e.protocol.DefaultTTL()
@@ -619,7 +612,7 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params
}
func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
- payload := pkt.TransportHeader().View().Size() + pkt.Data.Size()
+ payload := pkt.TransportHeader().View().Size() + pkt.Data().Size()
return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
}
@@ -819,14 +812,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
// The packet already has an IP header, but there are a few required checks.
- h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ h, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
if !ok {
return &tcpip.ErrMalformedHeader{}
}
ip := header.IPv6(h)
// Always set the payload length.
- pktSize := pkt.Data.Size()
+ pktSize := pkt.Data().Size()
ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize))
// Set the source address when zero.
@@ -964,7 +957,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
stats := e.stats.ip
h := header.IPv6(pkt.NetworkHeader().View())
- if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
+ if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
stats.MalformedPacketsReceived.Increment()
return
}
@@ -993,13 +986,14 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
+ // Create a VV to parse the packet. We don't plan to modify anything here.
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
// - Any other payload data.
vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView()
vv.AppendView(pkt.TransportHeader().View())
- vv.Append(pkt.Data)
+ vv.AppendViews(pkt.Data().Views())
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
// iptables filtering. All packets that reach here are intended for
@@ -1257,7 +1251,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// have more extension headers in the reassembled payload, as per RFC
// 8200 section 4.5. We also use the NextHeader value from the first
// fragment.
- it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), pkt.Data)
+ data := pkt.Data()
+ dataVV := buffer.NewVectorisedView(data.Size(), data.Views())
+ it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), dataVV)
}
case header.IPv6DestinationOptionsExtHdr:
@@ -1314,7 +1310,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// For reassembled fragments, pkt.TransportHeader is unset, so this is a
// no-op and pkt.Data begins with the transport header.
extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
- pkt.Data = extHdr.Buf
+ pkt.Data().Replace(extHdr.Buf)
stats.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
@@ -1381,8 +1377,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
func (e *endpoint) Close() {
e.mu.Lock()
e.disableLocked()
- e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */)
- e.stopDADForPermanentAddressesLocked()
e.mu.addressableEndpointState.Cleanup()
e.mu.Unlock()
@@ -1448,14 +1442,14 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
return &tcpip.ErrBadLocalAddress{}
}
- return e.removePermanentEndpointLocked(addressEndpoint, true /* allowSLAACInvalidation */, false /* dadFailure */)
+ return e.removePermanentEndpointLocked(addressEndpoint, true /* allowSLAACInvalidation */, &stack.DADAborted{})
}
// removePermanentEndpointLocked is like removePermanentAddressLocked except
// it works with a stack.AddressEndpoint.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation, dadFailure bool) tcpip.Error {
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool, dadResult stack.DADResult) tcpip.Error {
addr := addressEndpoint.AddressWithPrefix()
// If we are removing an address generated via SLAAC, cleanup
// its SLAAC resources and notify the integrator.
@@ -1466,16 +1460,16 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr)
}
- return e.removePermanentEndpointInnerLocked(addressEndpoint, dadFailure)
+ return e.removePermanentEndpointInnerLocked(addressEndpoint, dadResult)
}
// removePermanentEndpointInnerLocked is like removePermanentEndpointLocked
// except it does not cleanup SLAAC address state.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) removePermanentEndpointInnerLocked(addressEndpoint stack.AddressEndpoint, dadFailure bool) tcpip.Error {
+func (e *endpoint) removePermanentEndpointInnerLocked(addressEndpoint stack.AddressEndpoint, dadResult stack.DADResult) tcpip.Error {
addr := addressEndpoint.AddressWithPrefix()
- e.mu.ndp.stopDuplicateAddressDetection(addr.Address, dadFailure)
+ e.mu.ndp.stopDuplicateAddressDetection(addr.Address, dadResult)
if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil {
return err
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 05c9f4dbf..266a53e3b 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -68,7 +68,11 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: src,
+ Dst: dst,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -216,7 +220,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
// Store the reassembled payload as we parse each fragment. The payload
// includes the Transport header and everything after.
reassembledPayload.AppendView(fragment.TransportHeader().View())
- reassembledPayload.Append(fragment.Data)
+ reassembledPayload.AppendView(fragment.Data().AsRange().ToOwnedView())
}
if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" {
@@ -3065,7 +3069,11 @@ func TestForwarding(t *testing.T) {
icmp.SetType(header.ICMPv6EchoRequest)
icmp.SetCode(header.ICMPv6UnusedCode)
icmp.SetChecksum(0)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp,
+ Src: remoteIPv6Addr1,
+ Dst: remoteIPv6Addr2,
+ }))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: header.ICMPv6MinimumSize,
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index 205e36cdd..dd153466d 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -236,7 +236,11 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp
localAddress = header.IPv6Any
}
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp,
+ Src: localAddress,
+ Dst: destAddress,
+ }))
extensionHeaders := header.IPv6ExtHdrSerializer{
header.IPv6SerializableHopByHopExtHdr{
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index f1b8d58f2..9a425e50a 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -326,11 +326,15 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, ho
mld := header.MLD(icmp.MessageBody())
mld.SetMaximumResponseDelay(0)
mld.SetMulticastAddress(header.IPv6Any)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, srcAddress, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp,
+ Src: srcAddress,
+ Dst: header.IPv6AllNodesMulticastAddress,
+ }))
- e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
}
func TestMLDPacketValidation(t *testing.T) {
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 721269c58..d9b728878 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -208,16 +208,12 @@ const (
// NDPDispatcher is the interface integrators of netstack must implement to
// receive and handle NDP related events.
type NDPDispatcher interface {
- // OnDuplicateAddressDetectionStatus is called when the DAD process for an
- // address (addr) on a NIC (with ID nicID) completes. resolved is set to true
- // if DAD completed successfully (no duplicate addr detected); false otherwise
- // (addr was detected to be a duplicate on the link the NIC is a part of, or
- // it was stopped for some other reason, such as the address being removed).
- // If an error occured during DAD, err is set and resolved must be ignored.
+ // OnDuplicateAddressDetectionResult is called when the DAD process for an
+ // address on a NIC completes.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error)
+ OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult)
// OnDefaultRouterDiscovered is called when a new default router is
// discovered. Implementations must return true if the newly discovered
@@ -225,14 +221,14 @@ type NDPDispatcher interface {
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
+ OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool
// OnDefaultRouterInvalidated is called when a discovered default router that
// was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
+ OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address)
// OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
// Implementations must return true if the newly discovered on-link prefix
@@ -240,14 +236,14 @@ type NDPDispatcher interface {
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
+ OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool
// OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
// was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
+ OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet)
// OnAutoGenAddress is called when a new prefix with its autonomous address-
// configuration flag set is received and SLAAC was performed. Implementations
@@ -280,12 +276,12 @@ type NDPDispatcher interface {
// It is up to the caller to use the DNS Servers only for their valid
// lifetime. OnRecursiveDNSServerOption may be called for new or
// already known DNS servers. If called with known DNS servers, their
- // valid lifetimes must be refreshed to lifetime (it may be increased,
- // decreased, or completely invalidated when lifetime = 0).
+ // valid lifetimes must be refreshed to the lifetime (it may be increased,
+ // decreased, or completely invalidated when the lifetime = 0).
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
- OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+ OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration)
// OnDNSSearchListOption is called when the stack learns of DNS search lists
// through NDP.
@@ -293,9 +289,9 @@ type NDPDispatcher interface {
// It is up to the caller to use the domain names in the search list
// for only their valid lifetime. OnDNSSearchListOption may be called
// with new or already known domain names. If called with known domain
- // names, their valid lifetimes must be refreshed to lifetime (it may
- // be increased, decreased or completely invalidated when lifetime = 0.
- OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
+ // names, their valid lifetimes must be refreshed to the lifetime (it may
+ // be increased, decreased or completely invalidated when the lifetime = 0.
+ OnDNSSearchListOption(tcpip.NICID, []string, time.Duration)
// OnDHCPv6Configuration is called with an updated configuration that is
// available via DHCPv6 for the passed NIC.
@@ -587,15 +583,25 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
- if r.Resolved {
+ var dadSucceeded bool
+ switch r.(type) {
+ case *stack.DADAborted, *stack.DADError, *stack.DADDupAddrDetected:
+ dadSucceeded = false
+ case *stack.DADSucceeded:
+ dadSucceeded = true
+ default:
+ panic(fmt.Sprintf("unrecognized DAD result = %T", r))
+ }
+
+ if dadSucceeded {
addressEndpoint.SetKind(stack.Permanent)
}
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, r.Resolved, r.Err)
+ ndpDisp.OnDuplicateAddressDetectionResult(ndp.ep.nic.ID(), addr, r)
}
- if r.Resolved {
+ if dadSucceeded {
if addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
// Reset the generation attempts counter as we are starting the
// generation of a new address for the SLAAC prefix.
@@ -616,7 +622,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
+ ndpDisp.OnDuplicateAddressDetectionResult(ndp.ep.nic.ID(), addr, &stack.DADSucceeded{})
}
ndp.ep.onAddressAssignedLocked(addr)
@@ -633,8 +639,8 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// of this function to handle such a scenario.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address, failed bool) {
- ndp.dad.StopLocked(addr, !failed)
+func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address, reason stack.DADResult) {
+ ndp.dad.StopLocked(addr, reason)
}
// handleRA handles a Router Advertisement message that arrived on the NIC
@@ -1501,7 +1507,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi
ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
- if err := ndp.ep.removePermanentEndpointInnerLocked(addressEndpoint, false /* dadFailure */); err != nil {
+ if err := ndp.ep.removePermanentEndpointInnerLocked(addressEndpoint, &stack.DADAborted{}); err != nil {
panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err))
}
}
@@ -1560,7 +1566,7 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
ndp.cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs, tempAddr, tempAddrState)
- if err := ndp.ep.removePermanentEndpointInnerLocked(tempAddrState.addressEndpoint, false /* dadFailure */); err != nil {
+ if err := ndp.ep.removePermanentEndpointInnerLocked(tempAddrState.addressEndpoint, &stack.DADAborted{}); err != nil {
panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err))
}
}
@@ -1721,7 +1727,11 @@ func (ndp *ndpState) startSolicitingRouters() {
icmpData.SetType(header.ICMPv6RouterSolicit)
rs := header.NDPRouterSolicit(icmpData.MessageBody())
rs.Options().Serialize(optsSerializer)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpData,
+ Src: localAddr,
+ Dst: header.IPv6AllRoutersMulticastAddress,
+ }))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
@@ -1812,7 +1822,11 @@ func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteL
ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(opts)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, srcAddr, dstAddr, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp,
+ Src: srcAddr,
+ Dst: dstAddr,
+ }))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(e.MaxHeaderLength()),
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index ce20af0e3..6e850fd46 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -90,7 +90,7 @@ type testNDPDispatcher struct {
addr tcpip.Address
}
-func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) {
+func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
@@ -215,7 +215,11 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: lladdr1,
+ Dst: lladdr0,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -478,7 +482,11 @@ func TestNeighborSolicitationResponse(t *testing.T) {
ns.SetTargetAddress(nicAddr)
opts := ns.Options()
opts.Serialize(test.nsOpts)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: test.nsSrc,
+ Dst: test.nsDst,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -554,7 +562,11 @@ func TestNeighborSolicitationResponse(t *testing.T) {
na.SetOverrideFlag(true)
na.SetTargetAddress(test.nsSrc)
na.Options().Serialize(ser)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: test.nsSrc,
+ Dst: nicAddr,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -657,7 +669,11 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: lladdr1,
+ Dst: lladdr0,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -874,7 +890,13 @@ func TestNDPValidation(t *testing.T) {
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp[:typ.size],
+ Src: lladdr0,
+ Dst: lladdr1,
+ PayloadCsum: header.Checksum(typ.extraData /* initial */, 0),
+ PayloadLen: len(typ.extraData),
+ }))
// Rx count of the NDP message should initially be 0.
if got := typStat.Value(); got != 0 {
@@ -987,7 +1009,11 @@ func TestNeighborAdvertisementValidation(t *testing.T) {
na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetTargetAddress(lladdr1)
na.SetSolicitedFlag(test.solicitedFlag)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, test.ipDstAddr, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: lladdr1,
+ Dst: test.ipDstAddr,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -1182,7 +1208,11 @@ func TestRouterAdvertValidation(t *testing.T) {
pkt.SetCode(test.code)
copy(pkt.MessageBody(), test.ndpPayload)
payloadLength := hdr.UsedLength()
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: test.src,
+ Dst: header.IPv6AllNodesMulticastAddress,
+ }))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
@@ -1284,10 +1314,10 @@ func TestCheckDuplicateAddress(t *testing.T) {
t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning)
}
- // Wait for DAD to resolve.
+ // Wait for DAD to complete.
clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
for i := 0; i < dadRequestsMade; i++ {
- if diff := cmp.Diff(stack.DADResult{Resolved: true}, <-ch); diff != "" {
+ if diff := cmp.Diff(&stack.DADSucceeded{}, <-ch); diff != "" {
t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
}
}
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index 73913aef8..ecd5003a7 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -230,9 +230,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime b
igmp.SetGroupAddress(groupAddress)
igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
- e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
}
// createAndInjectMLDPacket creates and injects an MLD packet with the
@@ -263,11 +263,15 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay b
mld := header.MLD(icmp.MessageBody())
mld.SetMaximumResponseDelay(uint16(maxRespDelay))
mld.SetMulticastAddress(groupAddress)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, linkLocalIPv6Addr2, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmp,
+ Src: linkLocalIPv6Addr2,
+ Dst: header.IPv6AllNodesMulticastAddress,
+ }))
- e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
}
// TestMGPDisabled tests that the multicast group protocol is not enabled by
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 57abec5c9..210262703 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -4,7 +4,10 @@ package(licenses = ["notice"])
go_library(
name = "ports",
- srcs = ["ports.go"],
+ srcs = [
+ "flags.go",
+ "ports.go",
+ ],
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
diff --git a/pkg/tcpip/ports/flags.go b/pkg/tcpip/ports/flags.go
new file mode 100644
index 000000000..a8d7bff25
--- /dev/null
+++ b/pkg/tcpip/ports/flags.go
@@ -0,0 +1,150 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ports
+
+// Flags represents the type of port reservation.
+//
+// +stateify savable
+type Flags struct {
+ // MostRecent represents UDP SO_REUSEADDR.
+ MostRecent bool
+
+ // LoadBalanced indicates SO_REUSEPORT.
+ //
+ // LoadBalanced takes precidence over MostRecent.
+ LoadBalanced bool
+
+ // TupleOnly represents TCP SO_REUSEADDR.
+ TupleOnly bool
+}
+
+// Bits converts the Flags to their bitset form.
+func (f Flags) Bits() BitFlags {
+ var rf BitFlags
+ if f.MostRecent {
+ rf |= MostRecentFlag
+ }
+ if f.LoadBalanced {
+ rf |= LoadBalancedFlag
+ }
+ if f.TupleOnly {
+ rf |= TupleOnlyFlag
+ }
+ return rf
+}
+
+// Effective returns the effective behavior of a flag config.
+func (f Flags) Effective() Flags {
+ e := f
+ if e.LoadBalanced && e.MostRecent {
+ e.MostRecent = false
+ }
+ return e
+}
+
+// BitFlags is a bitset representation of Flags.
+type BitFlags uint32
+
+const (
+ // MostRecentFlag represents Flags.MostRecent.
+ MostRecentFlag BitFlags = 1 << iota
+
+ // LoadBalancedFlag represents Flags.LoadBalanced.
+ LoadBalancedFlag
+
+ // TupleOnlyFlag represents Flags.TupleOnly.
+ TupleOnlyFlag
+
+ // nextFlag is the value that the next added flag will have.
+ //
+ // It is used to calculate FlagMask below. It is also the number of
+ // valid flag states.
+ nextFlag
+
+ // FlagMask is a bit mask for BitFlags.
+ FlagMask = nextFlag - 1
+
+ // MultiBindFlagMask contains the flags that allow binding the same
+ // tuple multiple times.
+ MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
+)
+
+// ToFlags converts the bitset into a Flags struct.
+func (f BitFlags) ToFlags() Flags {
+ return Flags{
+ MostRecent: f&MostRecentFlag != 0,
+ LoadBalanced: f&LoadBalancedFlag != 0,
+ TupleOnly: f&TupleOnlyFlag != 0,
+ }
+}
+
+// FlagCounter counts how many references each flag combination has.
+type FlagCounter struct {
+ // refs stores the count for each possible flag combination, (0 though
+ // FlagMask).
+ refs [nextFlag]int
+}
+
+// AddRef increases the reference count for a specific flag combination.
+func (c *FlagCounter) AddRef(flags BitFlags) {
+ c.refs[flags]++
+}
+
+// DropRef decreases the reference count for a specific flag combination.
+func (c *FlagCounter) DropRef(flags BitFlags) {
+ c.refs[flags]--
+}
+
+// TotalRefs calculates the total number of references for all flag
+// combinations.
+func (c FlagCounter) TotalRefs() int {
+ var total int
+ for _, r := range c.refs {
+ total += r
+ }
+ return total
+}
+
+// FlagRefs returns the number of references with all specified flags.
+func (c FlagCounter) FlagRefs(flags BitFlags) int {
+ var total int
+ for i, r := range c.refs {
+ if BitFlags(i)&flags == flags {
+ total += r
+ }
+ }
+ return total
+}
+
+// AllRefsHave returns if all references have all specified flags.
+func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
+ for i, r := range c.refs {
+ if BitFlags(i)&flags != flags && r > 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// SharedFlags returns the set of flags shared by all references.
+func (c FlagCounter) SharedFlags() BitFlags {
+ intersection := FlagMask
+ for i, r := range c.refs {
+ if r > 0 {
+ intersection &= BitFlags(i)
+ }
+ }
+ return intersection
+}
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 11dbdbbcf..678199371 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ports provides PortManager that manages allocating, reserving and releasing ports.
+// Package ports provides PortManager that manages allocating, reserving and
+// releasing ports.
package ports
import (
- "math"
"math/rand"
"sync/atomic"
@@ -24,169 +24,44 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-const (
- // FirstEphemeral is the first ephemeral port.
- FirstEphemeral = 16000
+const anyIPAddress tcpip.Address = ""
- // numEphemeralPorts it the mnumber of available ephemeral ports to
- // Netstack.
- numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1
+// Reservation describes a port reservation.
+type Reservation struct {
+ // Networks is a list of network protocols to which the reservation
+ // applies. Can be IPv4, IPv6, or both.
+ Networks []tcpip.NetworkProtocolNumber
- anyIPAddress tcpip.Address = ""
-)
-
-type portDescriptor struct {
- network tcpip.NetworkProtocolNumber
- transport tcpip.TransportProtocolNumber
- port uint16
-}
-
-// Flags represents the type of port reservation.
-//
-// +stateify savable
-type Flags struct {
- // MostRecent represents UDP SO_REUSEADDR.
- MostRecent bool
-
- // LoadBalanced indicates SO_REUSEPORT.
- //
- // LoadBalanced takes precidence over MostRecent.
- LoadBalanced bool
-
- // TupleOnly represents TCP SO_REUSEADDR.
- TupleOnly bool
-}
-
-// Bits converts the Flags to their bitset form.
-func (f Flags) Bits() BitFlags {
- var rf BitFlags
- if f.MostRecent {
- rf |= MostRecentFlag
- }
- if f.LoadBalanced {
- rf |= LoadBalancedFlag
- }
- if f.TupleOnly {
- rf |= TupleOnlyFlag
- }
- return rf
-}
-
-// Effective returns the effective behavior of a flag config.
-func (f Flags) Effective() Flags {
- e := f
- if e.LoadBalanced && e.MostRecent {
- e.MostRecent = false
- }
- return e
-}
-
-// PortManager manages allocating, reserving and releasing ports.
-type PortManager struct {
- mu sync.RWMutex
- allocatedPorts map[portDescriptor]bindAddresses
-
- // hint is used to pick ports ephemeral ports in a stable order for
- // a given port offset.
- //
- // hint must be accessed using the portHint/incPortHint helpers.
- // TODO(gvisor.dev/issue/940): S/R this field.
- hint uint32
-}
-
-// BitFlags is a bitset representation of Flags.
-type BitFlags uint32
-
-const (
- // MostRecentFlag represents Flags.MostRecent.
- MostRecentFlag BitFlags = 1 << iota
-
- // LoadBalancedFlag represents Flags.LoadBalanced.
- LoadBalancedFlag
-
- // TupleOnlyFlag represents Flags.TupleOnly.
- TupleOnlyFlag
-
- // nextFlag is the value that the next added flag will have.
- //
- // It is used to calculate FlagMask below. It is also the number of
- // valid flag states.
- nextFlag
-
- // FlagMask is a bit mask for BitFlags.
- FlagMask = nextFlag - 1
+ // Transport is the transport protocol to which the reservation applies.
+ Transport tcpip.TransportProtocolNumber
- // MultiBindFlagMask contains the flags that allow binding the same
- // tuple multiple times.
- MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
-)
-
-// ToFlags converts the bitset into a Flags struct.
-func (f BitFlags) ToFlags() Flags {
- return Flags{
- MostRecent: f&MostRecentFlag != 0,
- LoadBalanced: f&LoadBalancedFlag != 0,
- TupleOnly: f&TupleOnlyFlag != 0,
- }
-}
+ // Addr is the address of the local endpoint.
+ Addr tcpip.Address
-// FlagCounter counts how many references each flag combination has.
-type FlagCounter struct {
- // refs stores the count for each possible flag combination, (0 though
- // FlagMask).
- refs [nextFlag]int
-}
+ // Port is the local port number.
+ Port uint16
-// AddRef increases the reference count for a specific flag combination.
-func (c *FlagCounter) AddRef(flags BitFlags) {
- c.refs[flags]++
-}
+ // Flags describe features of the reservation.
+ Flags Flags
-// DropRef decreases the reference count for a specific flag combination.
-func (c *FlagCounter) DropRef(flags BitFlags) {
- c.refs[flags]--
-}
+ // BindToDevice is the NIC to which the reservation applies.
+ BindToDevice tcpip.NICID
-// TotalRefs calculates the total number of references for all flag
-// combinations.
-func (c FlagCounter) TotalRefs() int {
- var total int
- for _, r := range c.refs {
- total += r
- }
- return total
+ // Dest is the destination address.
+ Dest tcpip.FullAddress
}
-// FlagRefs returns the number of references with all specified flags.
-func (c FlagCounter) FlagRefs(flags BitFlags) int {
- var total int
- for i, r := range c.refs {
- if BitFlags(i)&flags == flags {
- total += r
- }
- }
- return total
-}
-
-// AllRefsHave returns if all references have all specified flags.
-func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
- for i, r := range c.refs {
- if BitFlags(i)&flags != flags && r > 0 {
- return false
- }
+func (rs Reservation) dst() destination {
+ return destination{
+ rs.Dest.Addr,
+ rs.Dest.Port,
}
- return true
}
-// IntersectionRefs returns the set of flags shared by all references.
-func (c FlagCounter) IntersectionRefs() BitFlags {
- intersection := FlagMask
- for i, r := range c.refs {
- if r > 0 {
- intersection &= BitFlags(i)
- }
- }
- return intersection
+type portDescriptor struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+ port uint16
}
type destination struct {
@@ -194,18 +69,14 @@ type destination struct {
port uint16
}
-func makeDestination(a tcpip.FullAddress) destination {
- return destination{
- a.Addr,
- a.Port,
- }
-}
-
-// portNode is never empty. When it has no elements, it is removed from the
-// map that references it.
-type portNode map[destination]FlagCounter
+// destToCounter maps each destination to the FlagCounter that represents
+// endpoints to that destination.
+//
+// destToCounter is never empty. When it has no elements, it is removed from
+// the map that references it.
+type destToCounter map[destination]FlagCounter
-// intersectionRefs calculates the intersection of flag bit values which affect
+// intersectionFlags calculates the intersection of flag bit values which affect
// the specified destination.
//
// If no destinations are present, all flag values are returned as there are no
@@ -213,20 +84,20 @@ type portNode map[destination]FlagCounter
//
// In addition to the intersection, the number of intersecting refs is
// returned.
-func (p portNode) intersectionRefs(dst destination) (BitFlags, int) {
+func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) {
intersection := FlagMask
var count int
- for d, f := range p {
- if d == dst {
- intersection &= f.IntersectionRefs()
+ for dest, counter := range dc {
+ if dest == res.dst() {
+ intersection &= counter.SharedFlags()
count++
continue
}
// Wildcard destinations affect all destinations for TupleOnly.
- if d.addr == anyIPAddress || dst.addr == anyIPAddress {
+ if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress {
// Only bitwise and the TupleOnlyFlag.
- intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs())
+ intersection &= ((^TupleOnlyFlag) | counter.SharedFlags())
count++
}
}
@@ -234,27 +105,29 @@ func (p portNode) intersectionRefs(dst destination) (BitFlags, int) {
return intersection, count
}
-// deviceNode is never empty. When it has no elements, it is removed from the
+// deviceToDest maps NICs to destinations for which there are port reservations.
+//
+// deviceToDest is never empty. When it has no elements, it is removed from the
// map that references it.
-type deviceNode map[tcpip.NICID]portNode
+type deviceToDest map[tcpip.NICID]destToCounter
-// isAvailable checks whether binding is possible by device. If not binding to a
-// device, check against all FlagCounters. If binding to a specific device, check
-// against the unspecified device and the provided device.
+// isAvailable checks whether binding is possible by device. If not binding to
+// a device, check against all FlagCounters. If binding to a specific device,
+// check against the unspecified device and the provided device.
//
// If either of the port reuse flags is enabled on any of the nodes, all nodes
// sharing a port must share at least one reuse flag. This matches Linux's
// behavior.
-func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- flagBits := flags.Bits()
- if bindToDevice == 0 {
+func (dd deviceToDest) isAvailable(res Reservation) bool {
+ flagBits := res.Flags.Bits()
+ if res.BindToDevice == 0 {
intersection := FlagMask
- for _, p := range d {
- i, c := p.intersectionRefs(dst)
- if c == 0 {
+ for _, dest := range dd {
+ flags, count := dest.intersectionFlags(res)
+ if count == 0 {
continue
}
- intersection &= i
+ intersection &= flags
if intersection&flagBits == 0 {
// Can't bind because the (addr,port) was
// previously bound without reuse.
@@ -266,18 +139,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti
intersection := FlagMask
- if p, ok := d[0]; ok {
- var c int
- intersection, c = p.intersectionRefs(dst)
- if c > 0 && intersection&flagBits == 0 {
+ if dests, ok := dd[0]; ok {
+ var count int
+ intersection, count = dests.intersectionFlags(res)
+ if count > 0 && intersection&flagBits == 0 {
return false
}
}
- if p, ok := d[bindToDevice]; ok {
- i, c := p.intersectionRefs(dst)
- intersection &= i
- if c > 0 && intersection&flagBits == 0 {
+ if dests, ok := dd[res.BindToDevice]; ok {
+ flags, count := dests.intersectionFlags(res)
+ intersection &= flags
+ if count > 0 && intersection&flagBits == 0 {
return false
}
}
@@ -285,18 +158,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti
return true
}
-// bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]deviceNode
+// addrToDevice maps IP addresses to NICs that have port reservations.
+type addrToDevice map[tcpip.Address]deviceToDest
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
-func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- if addr == anyIPAddress {
- // If binding to the "any" address then check that there are no conflicts
- // with all addresses.
- for _, d := range b {
- if !d.isAvailable(flags, bindToDevice, dst) {
+func (ad addrToDevice) isAvailable(res Reservation) bool {
+ if res.Addr == anyIPAddress {
+ // If binding to the "any" address then check that there are no
+ // conflicts with all addresses.
+ for _, devices := range ad {
+ if !devices.isAvailable(res) {
return false
}
}
@@ -304,15 +177,15 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice
}
// Check that there is no conflict with the "any" address.
- if d, ok := b[anyIPAddress]; ok {
- if !d.isAvailable(flags, bindToDevice, dst) {
+ if devices, ok := ad[anyIPAddress]; ok {
+ if !devices.isAvailable(res) {
return false
}
}
// Check that this is no conflict with the provided address.
- if d, ok := b[addr]; ok {
- if !d.isAvailable(flags, bindToDevice, dst) {
+ if devices, ok := ad[res.Addr]; ok {
+ if !devices.isAvailable(res) {
return false
}
}
@@ -320,50 +193,93 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice
return true
}
+// PortManager manages allocating, reserving and releasing ports.
+type PortManager struct {
+ // mu protects allocatedPorts.
+ // LOCK ORDERING: mu > ephemeralMu.
+ mu sync.RWMutex
+ // allocatedPorts is a nesting of maps that ultimately map Reservations
+ // to FlagCounters describing whether the Reservation is valid and can
+ // be reused.
+ allocatedPorts map[portDescriptor]addrToDevice
+
+ // ephemeralMu protects firstEphemeral and numEphemeral.
+ ephemeralMu sync.RWMutex
+ firstEphemeral uint16
+ numEphemeral uint16
+
+ // hint is used to pick ports ephemeral ports in a stable order for
+ // a given port offset.
+ //
+ // hint must be accessed using the portHint/incPortHint helpers.
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ hint uint32
+}
+
// NewPortManager creates new PortManager.
func NewPortManager() *PortManager {
- return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)}
+ return &PortManager{
+ allocatedPorts: make(map[portDescriptor]addrToDevice),
+ // Match Linux's default ephemeral range. See:
+ // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842
+ firstEphemeral: 32768,
+ numEphemeral: 28232,
+ }
}
+// PortTester indicates whether the passed in port is suitable. Returning an
+// error causes the function to which the PortTester is passed to return that
+// error.
+type PortTester func(port uint16) (good bool, err tcpip.Error)
+
// PickEphemeralPort randomly chooses a starting point and iterates over all
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
-func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
- offset := uint32(rand.Int31n(numEphemeralPorts))
- return s.pickEphemeralPort(offset, numEphemeralPorts, testPort)
+func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) {
+ pm.ephemeralMu.RLock()
+ firstEphemeral := pm.firstEphemeral
+ numEphemeral := pm.numEphemeral
+ pm.ephemeralMu.RUnlock()
+
+ offset := uint16(rand.Int31n(int32(numEphemeral)))
+ return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
}
-// portHint atomically reads and returns the s.hint value.
-func (s *PortManager) portHint() uint32 {
- return atomic.LoadUint32(&s.hint)
+// portHint atomically reads and returns the pm.hint value.
+func (pm *PortManager) portHint() uint16 {
+ return uint16(atomic.LoadUint32(&pm.hint))
}
-// incPortHint atomically increments s.hint by 1.
-func (s *PortManager) incPortHint() {
- atomic.AddUint32(&s.hint, 1)
+// incPortHint atomically increments pm.hint by 1.
+func (pm *PortManager) incPortHint() {
+ atomic.AddUint32(&pm.hint, 1)
}
-// PickEphemeralPortStable starts at the specified offset + s.portHint and
+// PickEphemeralPortStable starts at the specified offset + pm.portHint and
// iterates over all ephemeral ports, allowing the caller to decide whether a
// given port is suitable for its needs and stopping when a port is found or an
// error occurs.
-func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
- p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort)
+func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+ pm.ephemeralMu.RLock()
+ firstEphemeral := pm.firstEphemeral
+ numEphemeral := pm.numEphemeral
+ pm.ephemeralMu.RUnlock()
+
+ p, err := pickEphemeralPort(pm.portHint()+offset, firstEphemeral, numEphemeral, testPort)
if err == nil {
- s.incPortHint()
+ pm.incPortHint()
}
return p, err
-
}
// pickEphemeralPort starts at the offset specified from the FirstEphemeral port
// and iterates over the number of ports specified by count and allows the
// caller to decide whether a given port is suitable for its needs, and stopping
// when a port is found or an error occurs.
-func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
- for i := uint32(0); i < count; i++ {
- port = uint16(FirstEphemeral + (offset+i)%count)
+func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+ for i := uint16(0); i < count; i++ {
+ port = first + (offset+i)%count
ok, err := testPort(port)
if err != nil {
return 0, err
@@ -377,144 +293,145 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui
return 0, &tcpip.ErrNoPortAvailable{}
}
-// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
- s.mu.Lock()
- defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest))
-}
-
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, flags, bindToDevice, dst) {
- return false
- }
- }
- }
- return true
-}
-
// ReservePort marks a port/IP combination as reserved so that it cannot be
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
//
-// An optional testPort closure can be passed in which if provided will be used
-// to test if the picked port can be used. The function should return true if
-// the port is safe to use, false otherwise.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- dst := makeDestination(dest)
+// An optional PortTester can be passed in which if provided will be used to
+// test if the picked port can be used. The function should return true if the
+// port is safe to use, false otherwise.
+func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
- if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) {
+ if res.Port != 0 {
+ if !pm.reserveSpecificPortLocked(res) {
return 0, &tcpip.ErrPortInUse{}
}
- if testPort != nil && !testPort(port) {
- s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst)
- return 0, &tcpip.ErrPortInUse{}
+ if testPort != nil {
+ ok, err := testPort(res.Port)
+ if err != nil {
+ pm.releasePortLocked(res)
+ return 0, err
+ }
+ if !ok {
+ pm.releasePortLocked(res)
+ return 0, &tcpip.ErrPortInUse{}
+ }
}
- return port, nil
+ return res.Port, nil
}
// A port wasn't specified, so try to find one.
- return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
- if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) {
+ return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ res.Port = p
+ if !pm.reserveSpecificPortLocked(res) {
return false, nil
}
- if testPort != nil && !testPort(p) {
- s.releasePortLocked(networks, transport, addr, p, flags.Bits(), bindToDevice, dst)
- return false, nil
+ if testPort != nil {
+ ok, err := testPort(p)
+ if err != nil {
+ pm.releasePortLocked(res)
+ return false, err
+ }
+ if !ok {
+ pm.releasePortLocked(res)
+ return false, nil
+ }
}
return true, nil
})
}
-// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) {
- return false
+// reserveSpecificPortLocked tries to reserve the given port on all given
+// protocols.
+func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool {
+ // Make sure the port is available.
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ if addrs, ok := pm.allocatedPorts[desc]; ok {
+ if !addrs.isAvailable(res) {
+ return false
+ }
+ }
}
- flagBits := flags.Bits()
-
// Reserve port on all network protocols.
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- m, ok := s.allocatedPorts[desc]
+ flagBits := res.Flags.Bits()
+ dst := res.dst()
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ addrToDev, ok := pm.allocatedPorts[desc]
if !ok {
- m = make(bindAddresses)
- s.allocatedPorts[desc] = m
+ addrToDev = make(addrToDevice)
+ pm.allocatedPorts[desc] = addrToDev
}
- d, ok := m[addr]
+ devToDest, ok := addrToDev[res.Addr]
if !ok {
- d = make(deviceNode)
- m[addr] = d
+ devToDest = make(deviceToDest)
+ addrToDev[res.Addr] = devToDest
}
- p := d[bindToDevice]
- if p == nil {
- p = make(portNode)
+ destToCntr := devToDest[res.BindToDevice]
+ if destToCntr == nil {
+ destToCntr = make(destToCounter)
}
- n := p[dst]
- n.AddRef(flagBits)
- p[dst] = n
- d[bindToDevice] = p
+ counter := destToCntr[dst]
+ counter.AddRef(flagBits)
+ destToCntr[dst] = counter
+ devToDest[res.BindToDevice] = destToCntr
}
return true
}
// ReserveTuple adds a port reservation for the tuple on all given protocol.
-func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
- flagBits := flags.Bits()
- dst := makeDestination(dest)
+func (pm *PortManager) ReserveTuple(res Reservation) bool {
+ flagBits := res.Flags.Bits()
+ dst := res.dst()
- s.mu.Lock()
- defer s.mu.Unlock()
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
// It is easier to undo the entire reservation, so if we find that the
// tuple can't be fully added, finish and undo the whole thing.
undo := false
// Reserve port on all network protocols.
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- m, ok := s.allocatedPorts[desc]
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ addrToDev, ok := pm.allocatedPorts[desc]
if !ok {
- m = make(bindAddresses)
- s.allocatedPorts[desc] = m
+ addrToDev = make(addrToDevice)
+ pm.allocatedPorts[desc] = addrToDev
}
- d, ok := m[addr]
+ devToDest, ok := addrToDev[res.Addr]
if !ok {
- d = make(deviceNode)
- m[addr] = d
+ devToDest = make(deviceToDest)
+ addrToDev[res.Addr] = devToDest
}
- p := d[bindToDevice]
- if p == nil {
- p = make(portNode)
+ destToCntr := devToDest[res.BindToDevice]
+ if destToCntr == nil {
+ destToCntr = make(destToCounter)
}
- n := p[dst]
- if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 {
+ counter := destToCntr[dst]
+ if counter.TotalRefs() != 0 && counter.SharedFlags()&flagBits == 0 {
// Tuple already exists.
undo = true
}
- n.AddRef(flagBits)
- p[dst] = n
- d[bindToDevice] = p
+ counter.AddRef(flagBits)
+ destToCntr[dst] = counter
+ devToDest[res.BindToDevice] = destToCntr
}
if undo {
// releasePortLocked decrements the counts (rather than setting
// them to zero), so it will undo the incorrect incrementing
// above.
- s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst)
+ pm.releasePortLocked(res)
return false
}
@@ -523,47 +440,71 @@ func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, trans
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) {
- s.mu.Lock()
- defer s.mu.Unlock()
+func (pm *PortManager) ReleasePort(res Reservation) {
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
- s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest))
+ pm.releasePortLocked(res)
}
-func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) {
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- if m, ok := s.allocatedPorts[desc]; ok {
- d, ok := m[addr]
- if !ok {
- continue
- }
- p, ok := d[bindToDevice]
- if !ok {
- continue
- }
- n, ok := p[dst]
- if !ok {
- continue
- }
- n.DropRef(flags)
- if n.TotalRefs() > 0 {
- p[dst] = n
- continue
- }
- delete(p, dst)
- if len(p) > 0 {
- continue
- }
- delete(d, bindToDevice)
- if len(d) > 0 {
- continue
- }
- delete(m, addr)
- if len(m) > 0 {
- continue
- }
- delete(s.allocatedPorts, desc)
+func (pm *PortManager) releasePortLocked(res Reservation) {
+ dst := res.dst()
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ addrToDev, ok := pm.allocatedPorts[desc]
+ if !ok {
+ continue
}
+ devToDest, ok := addrToDev[res.Addr]
+ if !ok {
+ continue
+ }
+ destToCounter, ok := devToDest[res.BindToDevice]
+ if !ok {
+ continue
+ }
+ counter, ok := destToCounter[dst]
+ if !ok {
+ continue
+ }
+ counter.DropRef(res.Flags.Bits())
+ if counter.TotalRefs() > 0 {
+ destToCounter[dst] = counter
+ continue
+ }
+ delete(destToCounter, dst)
+ if len(destToCounter) > 0 {
+ continue
+ }
+ delete(devToDest, res.BindToDevice)
+ if len(devToDest) > 0 {
+ continue
+ }
+ delete(addrToDev, res.Addr)
+ if len(addrToDev) > 0 {
+ continue
+ }
+ delete(pm.allocatedPorts, desc)
+ }
+}
+
+// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
+// both IPv4 and IPv6.
+func (pm *PortManager) PortRange() (uint16, uint16) {
+ pm.ephemeralMu.RLock()
+ defer pm.ephemeralMu.RUnlock()
+ return pm.firstEphemeral, pm.firstEphemeral + pm.numEphemeral - 1
+}
+
+// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range
+// (inclusive).
+func (pm *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error {
+ if start > end {
+ return &tcpip.ErrInvalidPortRange{}
}
+ pm.ephemeralMu.Lock()
+ defer pm.ephemeralMu.Unlock()
+ pm.firstEphemeral = start
+ pm.numEphemeral = end - start + 1
+ return nil
}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index e70fbb72b..0f43dc8f8 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -329,16 +329,35 @@ func TestPortReservation(t *testing.T) {
net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
for _, test := range test.actions {
+ first, _ := pm.PortRange()
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
+ portRes := Reservation{
+ Networks: net,
+ Transport: fakeTransNumber,
+ Addr: test.ip,
+ Port: test.port,
+ Flags: test.flags,
+ BindToDevice: test.device,
+ Dest: test.dest,
+ }
+ pm.ReleasePort(portRes)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */)
+ portRes := Reservation{
+ Networks: net,
+ Transport: fakeTransNumber,
+ Addr: test.ip,
+ Port: test.port,
+ Flags: test.flags,
+ BindToDevice: test.device,
+ Dest: test.dest,
+ }
+ gotPort, err := pm.ReservePort(portRes, nil /* testPort */)
if diff := cmp.Diff(test.want, err); diff != "" {
- t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff)
+ t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff)
}
- if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
- t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
+ if test.port == 0 && (gotPort == 0 || gotPort < first) {
+ t.Fatalf("ReservePort(%+v, _) = %d, want port number >= %d to be picked", portRes, gotPort, first)
}
}
})
@@ -346,6 +365,11 @@ func TestPortReservation(t *testing.T) {
}
func TestPickEphemeralPort(t *testing.T) {
+ const (
+ firstEphemeral = 32000
+ numEphemeralPorts = 1000
+ )
+
for _, test := range []struct {
name string
f func(port uint16) (bool, tcpip.Error)
@@ -369,17 +393,17 @@ func TestPickEphemeralPort(t *testing.T) {
{
name: "only-port-16042-available",
f: func(port uint16) (bool, tcpip.Error) {
- if port == FirstEphemeral+42 {
+ if port == firstEphemeral+42 {
return true, nil
}
return false, nil
},
- wantPort: FirstEphemeral + 42,
+ wantPort: firstEphemeral + 42,
},
{
name: "only-port-under-16000-available",
f: func(port uint16) (bool, tcpip.Error) {
- if port < FirstEphemeral {
+ if port < firstEphemeral {
return true, nil
}
return false, nil
@@ -389,6 +413,9 @@ func TestPickEphemeralPort(t *testing.T) {
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
+ if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
+ t.Fatalf("failed to set ephemeral port range: %s", err)
+ }
port, err := pm.PickEphemeralPort(test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
@@ -401,6 +428,11 @@ func TestPickEphemeralPort(t *testing.T) {
}
func TestPickEphemeralPortStable(t *testing.T) {
+ const (
+ firstEphemeral = 32000
+ numEphemeralPorts = 1000
+ )
+
for _, test := range []struct {
name string
f func(port uint16) (bool, tcpip.Error)
@@ -424,17 +456,17 @@ func TestPickEphemeralPortStable(t *testing.T) {
{
name: "only-port-16042-available",
f: func(port uint16) (bool, tcpip.Error) {
- if port == FirstEphemeral+42 {
+ if port == firstEphemeral+42 {
return true, nil
}
return false, nil
},
- wantPort: FirstEphemeral + 42,
+ wantPort: firstEphemeral + 42,
},
{
name: "only-port-under-16000-available",
f: func(port uint16) (bool, tcpip.Error) {
- if port < FirstEphemeral {
+ if port < firstEphemeral {
return true, nil
}
return false, nil
@@ -444,7 +476,10 @@ func TestPickEphemeralPortStable(t *testing.T) {
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
- portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
+ if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
+ t.Fatalf("failed to set ephemeral port range: %s", err)
+ }
+ portOffset := uint16(rand.Int31n(int32(numEphemeralPorts)))
port, err := pm.PickEphemeralPortStable(portOffset, test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index cdb435644..3f083928f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -407,12 +407,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data.Size())
+ length := uint16(len(tcpHeader) + pkt.Data().Size())
xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
} else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumVV(pkt.Data, xsum)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index d63e9757c..0e8b90c9b 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -153,7 +153,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
if r.RequiresTXTransportChecksum() {
length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumVV(pkt.Data, xsum)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 740bdac28..47796a6ba 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -99,12 +99,11 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi
}
// ndpDADEvent is a set of parameters that was passed to
-// ndpDispatcher.OnDuplicateAddressDetectionStatus.
+// ndpDispatcher.OnDuplicateAddressDetectionResult.
type ndpDADEvent struct {
- nicID tcpip.NICID
- addr tcpip.Address
- resolved bool
- err tcpip.Error
+ nicID tcpip.NICID
+ addr tcpip.Address
+ res stack.DADResult
}
type ndpRouterEvent struct {
@@ -173,14 +172,13 @@ type ndpDispatcher struct {
dhcpv6ConfigurationC chan ndpDHCPv6Event
}
-// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus.
-func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) {
+// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionResult.
+func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) {
if n.dadC != nil {
n.dadC <- ndpDADEvent{
nicID,
addr,
- resolved,
- err,
+ res,
}
}
}
@@ -311,8 +309,8 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
// Check e to make sure that the event is for addr on nic with ID 1, and the
// resolved flag set to resolved with the specified err.
-func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string {
- return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e))
+func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) string {
+ return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, res: res}, e, cmp.AllowUnexported(e))
}
// TestDADDisabled tests that an address successfully resolves immediately
@@ -344,8 +342,8 @@ func TestDADDisabled(t *testing.T) {
// DAD on it.
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected DAD event")
@@ -491,8 +489,8 @@ func TestDADResolve(t *testing.T) {
case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
}
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
@@ -573,7 +571,11 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(tgt)
snmc := header.SolicitedNodeAddr(tgt)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: header.IPv6Any,
+ Dst: snmc,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -594,9 +596,10 @@ func TestDADFail(t *testing.T) {
const nicID = 1
tests := []struct {
- name string
- rxPkt func(e *channel.Endpoint, tgt tcpip.Address)
- getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ name string
+ rxPkt func(e *channel.Endpoint, tgt tcpip.Address)
+ getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ expectedHolderLinkAddress tcpip.LinkAddress
}{
{
name: "RxSolicit",
@@ -604,6 +607,7 @@ func TestDADFail(t *testing.T) {
getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return s.NeighborSolicit
},
+ expectedHolderLinkAddress: "",
},
{
name: "RxAdvert",
@@ -619,7 +623,11 @@ func TestDADFail(t *testing.T) {
na.Options().Serialize(header.NDPOptionsSerializer{
header.NDPTargetLinkLayerAddressOption(linkAddr1),
})
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: tgt,
+ Dst: header.IPv6AllNodesMulticastAddress,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -634,6 +642,7 @@ func TestDADFail(t *testing.T) {
getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return s.NeighborAdvert
},
+ expectedHolderLinkAddress: linkAddr1,
},
}
@@ -683,8 +692,8 @@ func TestDADFail(t *testing.T) {
// something is wrong.
t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
}
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
@@ -782,8 +791,8 @@ func TestDADStop(t *testing.T) {
// time + extra 1s buffer, something is wrong.
t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, &tcpip.ErrAborted{}); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
}
@@ -844,8 +853,8 @@ func TestSetNDPConfigurations(t *testing.T) {
expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatalf("expected DAD event for %s", addr)
@@ -936,8 +945,8 @@ func TestSetNDPConfigurations(t *testing.T) {
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
}
if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
@@ -973,7 +982,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
}
opts := ra.Options()
opts.Serialize(optSer)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: ip,
+ Dst: header.IPv6AllNodesMulticastAddress,
+ }))
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
iph.Encode(&header.IPv6Fields{
@@ -1951,8 +1964,8 @@ func TestAutoGenTempAddr(t *testing.T) {
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
@@ -2157,8 +2170,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
}
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
@@ -2245,8 +2258,8 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
// address to be generated.
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
@@ -2711,8 +2724,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
t.Helper()
clock.Advance(dupAddrTransmits * retransmitTimer)
- if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
}
@@ -2742,8 +2755,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
rxNDPSolicit(e, addr.Address)
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected DAD event")
@@ -3841,26 +3854,26 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
}
- expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool, err tcpip.Error) {
+ expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
t.Helper()
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, resolved, err); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected DAD event")
}
}
- expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) {
+ expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
t.Helper()
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
@@ -3917,7 +3930,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// generated.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr)
- expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true)
+ expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{})
// The stable address will be assigned throughout the test.
return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest}
@@ -3992,7 +4005,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Simulate a DAD conflict.
rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
- expectDADEvent(t, &ndpDisp, addr.Address, false, nil)
+ expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{})
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
@@ -4002,7 +4015,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
}
- expectDADEvent(t, &ndpDisp, addr.Address, false, &tcpip.ErrAborted{})
+ expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{})
}
// Should not have any new addresses assigned to the NIC.
@@ -4015,7 +4028,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if maxRetries+1 > numFailures {
addr := addrType.addrGenFn(numFailures, tempIIDHistory[:])
expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
- expectDADEventAsync(t, &ndpDisp, addr.Address, true)
+ expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{})
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -4132,8 +4145,8 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected DAD event")
@@ -4231,8 +4244,8 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected DAD event")
@@ -4243,8 +4256,8 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
expectAutoGenAddrEvent(addr, newAddr)
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index afff1b434..48bb75e2f 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -59,21 +59,24 @@ const (
infiniteDuration = time.Duration(math.MaxInt64)
)
-// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor
-// entries. The UpdatedAtNanos field is ignored due to a lack of a
-// deterministic method to predict the time that an event will be dispatched.
-func entryDiffOpts() []cmp.Option {
+// unorderedEventsDiffOpts returns options passed to cmp.Diff to sort slices of
+// events for cases where ordering must be ignored.
+func unorderedEventsDiffOpts() []cmp.Option {
return []cmp.Option{
- cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"),
+ cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
+ return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0
+ }),
}
}
-// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to
-// sort slices of entries for cases where ordering must be ignored.
-func entryDiffOptsWithSort() []cmp.Option {
- return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool {
- return strings.Compare(string(a.Addr), string(b.Addr)) < 0
- }))
+// unorderedEntriesDiffOpts returns options passed to cmp.Diff to sort slices of
+// entries for cases where ordering must be ignored.
+func unorderedEntriesDiffOpts() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.SortSlices(func(a, b NeighborEntry) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }),
+ }
}
func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver {
@@ -280,48 +283,105 @@ func TestNeighborCacheSetConfig(t *testing.T) {
}
}
-func TestNeighborCacheEntry(t *testing.T) {
- c := DefaultNUDConfigurations()
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, c, clock)
+func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry, removed []NeighborEntry) error {
+ var gotLinkResolutionResult LinkResolutionResult
- entry, ok := linkRes.entries.entry(0)
- if !ok {
- t.Fatal("linkRes.entries.entry(0) not found")
- }
- _, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
+ _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
+ gotLinkResolutionResult = r
+ })
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ return fmt.Errorf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
- clock.Advance(typicalLatency)
+ {
+ var wantEvents []testEntryEventInfo
- wantEvents := []testEntryEventInfo{
- {
+ for _, removedEntry := range removed {
+ wantEvents = append(wantEvents, testEntryEventInfo{
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Entry: NeighborEntry{
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: clock.NowNanoseconds(),
+ },
+ })
+ }
+
+ wantEvents = append(wantEvents, testEntryEventInfo{
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Incomplete,
+ UpdatedAtNanos: clock.NowNanoseconds(),
},
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ })
+
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(wantEvents, nudDisp.mu.events)
+ nudDisp.mu.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
+ }
+
+ clock.Advance(typicalLatency)
+
+ select {
+ case <-ch:
+ default:
+ return fmt.Errorf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
+ }
+ wantLinkResolutionResult := LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}
+ if diff := cmp.Diff(wantLinkResolutionResult, gotLinkResolutionResult); diff != "" {
+ return fmt.Errorf("got link resolution result mismatch (-want +got):\n%s", diff)
+ }
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: clock.NowNanoseconds(),
+ },
},
- },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(wantEvents, nudDisp.mu.events)
+ nudDisp.mu.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
}
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+
+ return nil
+}
+
+func addReachableEntry(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry) error {
+ return addReachableEntryWithRemoved(nudDisp, clock, linkRes, entry, nil /* removed */)
+}
+
+func TestNeighborCacheEntry(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ nudDisp := testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ linkRes := newTestNeighborResolver(&nudDisp, c, clock)
+
+ entry, ok := linkRes.entries.entry(0)
+ if !ok {
+ t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
+ }
+ if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
+ t.Fatalf("addReachableEntry(...) = %s", err)
}
if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil {
@@ -345,41 +405,10 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("linkRes.entries.entry(0) not found")
+ t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
}
-
- _, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
-
- clock.Advance(typicalLatency)
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
+ t.Fatalf("addReachableEntry(...) = %s", err)
}
linkRes.neigh.removeEntry(entry.Addr)
@@ -390,14 +419,15 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: clock.NowNanoseconds(),
},
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.mu.events)
nudDisp.mu.Unlock()
if diff != "" {
t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
@@ -439,18 +469,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
// Fill the neighbor cache to capacity to verify the LRU eviction strategy is
// working properly after the entry removal.
for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ {
- // Add a new entry
- entry, ok := c.linkRes.entries.entry(i)
- if !ok {
- return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i)
- }
- _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- return fmt.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer)
-
- var wantEvents []testEntryEventInfo
+ var removedEntries []NeighborEntry
// When beyond the full capacity, the cache will evict an entry as per the
// LRU eviction strategy. Note that the number of static entries should not
@@ -458,63 +477,40 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if i >= neighborCacheSize+opts.startAtEntryIndex {
removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize)
if !ok {
- return fmt.Errorf("linkRes.entries.entry(%d) not found", i-neighborCacheSize)
+ return fmt.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize)
}
- wantEvents = append(wantEvents, testEntryEventInfo{
- EventType: entryTestRemoved,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
- },
- })
+ removedEntries = append(removedEntries, removedEntry)
}
- wantEvents = append(wantEvents, testEntryEventInfo{
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- }, testEntryEventInfo{
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- })
-
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ entry, ok := c.linkRes.entries.entry(i)
+ if !ok {
+ return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i)
+ }
+ if err := addReachableEntryWithRemoved(c.nudDisp, c.clock, c.linkRes, entry, removedEntries); err != nil {
+ return fmt.Errorf("addReachableEntryWithRemoved(...) = %s", err)
}
}
// Expect to find only the most recent entries. The order of entries reported
// by entries() is nondeterministic, so entries have to be sorted before
// comparison.
- wantUnsortedEntries := opts.wantStaticEntries
+ wantUnorderedEntries := opts.wantStaticEntries
for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ {
entry, ok := c.linkRes.entries.entry(i)
if !ok {
- return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i)
+ return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i)
}
+ durationReachableNanos := int64(c.linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds()
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: c.clock.NowNanoseconds() - durationReachableNanos,
}
- wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ wantUnorderedEntries = append(wantUnorderedEntries, wantEntry)
}
- if diff := cmp.Diff(wantUnsortedEntries, c.linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ if diff := cmp.Diff(wantUnorderedEntries, c.linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" {
return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
@@ -560,38 +556,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.linkRes.entries.entry(0) not found")
+ t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
}
- _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer)
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil {
+ t.Fatalf("addReachableEntry(...) = %s", err)
}
// Remove the entry
@@ -603,14 +571,15 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
},
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
c.nudDisp.mu.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
@@ -636,33 +605,36 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
// Add a static entry
entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.linkRes.entries.entry(0) not found")
+ t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
}
staticLinkAddr := entry.LinkAddr + "static"
c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
+ },
},
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
+ c.nudDisp.mu.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
}
- // Remove the static entry that was just added
+ // Add a duplicate static entry with the same link address.
c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
- // No more events should have been dispatched.
c.nudDisp.mu.Lock()
defer c.nudDisp.mu.Unlock()
if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.mu.events); diff != "" {
@@ -680,48 +652,56 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
// Add a static entry
entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.linkRes.entries.entry(0) not found")
+ t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
}
staticLinkAddr := entry.LinkAddr + "static"
c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
+ },
},
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
+ c.nudDisp.mu.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
}
// Add a duplicate entry with a different link address
staticLinkAddr += "duplicate"
c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+
{
wantEvents := []testEntryEventInfo{
{
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
},
},
}
c.nudDisp.mu.Lock()
- defer c.nudDisp.mu.Unlock()
- if diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
+ c.nudDisp.mu.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
}
@@ -742,45 +722,51 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
// Add a static entry
entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.linkRes.entries.entry(0) not found")
+ t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
}
staticLinkAddr := entry.LinkAddr + "static"
c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
+ },
},
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
+ c.nudDisp.mu.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
}
// Remove the static entry that was just added
c.linkRes.neigh.removeEntry(entry.Addr)
+
{
wantEvents := []testEntryEventInfo{
{
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
},
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
c.nudDisp.mu.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
@@ -812,66 +798,41 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.linkRes.entries.entry(0) not found")
- }
- _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- c.clock.Advance(typicalLatency)
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
+ t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
}
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil {
+ t.Fatalf("addReachableEntry(...) = %s", err)
}
// Override the entry with a static one using the same address
staticLinkAddr := entry.LinkAddr + "static"
c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+
{
wantEvents := []testEntryEventInfo{
{
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
},
},
{
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
},
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
c.nudDisp.mu.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
@@ -883,9 +844,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
},
},
}
@@ -905,7 +867,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.linkRes.entries.entry(0) not found")
+ t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
}
c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
@@ -913,40 +875,45 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
}
- if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ if diff := cmp.Diff(want, e); diff != "" {
t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
+ },
},
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
+ c.nudDisp.mu.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
}
opts := overflowOptions{
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
},
},
}
@@ -965,39 +932,10 @@ func TestNeighborCacheClear(t *testing.T) {
// Add a dynamic entry.
entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("linkRes.entries.entry(0) not found")
+ t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
}
- _, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- clock.Advance(typicalLatency)
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
+ t.Fatalf("addReachableEntry(...) = %s", err)
}
// Add a static entry.
@@ -1009,14 +947,15 @@ func TestNeighborCacheClear(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ UpdatedAtNanos: clock.NowNanoseconds(),
},
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.mu.events)
nudDisp.mu.events = nil
nudDisp.mu.Unlock()
if diff != "" {
@@ -1028,30 +967,32 @@ func TestNeighborCacheClear(t *testing.T) {
linkRes.neigh.clear()
// Remove events dispatched from clear() have no deterministic order so they
- // need to be sorted beforehand.
- wantUnsortedEvents := []testEntryEventInfo{
+ // need to be sorted before comparison.
+ wantUnorderedEvents := []testEntryEventInfo{
{
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: clock.NowNanoseconds(),
},
},
{
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ UpdatedAtNanos: clock.NowNanoseconds(),
},
},
}
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(wantUnsortedEvents, nudDisp.mu.events, eventDiffOptsWithSort()...); diff != "" {
+ if diff := cmp.Diff(wantUnorderedEvents, nudDisp.mu.events, unorderedEventsDiffOpts()...); diff != "" {
t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1071,56 +1012,30 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.linkRes.entries.entry(0) not found")
- }
- _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- c.clock.Advance(typicalLatency)
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
+ t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
}
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil {
+ t.Fatalf("addReachableEntry(...) = %s", err)
}
// Clear the cache.
c.linkRes.neigh.clear()
+
{
wantEvents := []testEntryEventInfo{
{
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: c.clock.NowNanoseconds(),
},
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
c.nudDisp.mu.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
@@ -1147,10 +1062,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
clock := faketime.NewManualClock()
linkRes := newTestNeighborResolver(&nudDisp, config, clock)
- frequentlyUsedEntry, ok := linkRes.entries.entry(0)
- if !ok {
- t.Fatal("linkRes.entries.entry(0) not found")
- }
+ startedAt := clock.NowNanoseconds()
// The following logic is very similar to overflowCache, but
// periodically refreshes the frequently used entry.
@@ -1159,50 +1071,18 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
for i := 0; i < neighborCacheSize; i++ {
entry, ok := linkRes.entries.entry(i)
if !ok {
- t.Fatalf("linkRes.entries.entry(%d) not found", i)
+ t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
}
- _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
- if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" {
- t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- clock.Advance(typicalLatency)
- select {
- case <-ch:
- default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
- }
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
+ t.Fatalf("addReachableEntry(...) = %s", err)
}
}
+ frequentlyUsedEntry, ok := linkRes.entries.entry(0)
+ if !ok {
+ t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
+ }
+
// Keep adding more entries
for i := neighborCacheSize; i < linkRes.entries.size(); i++ {
// Periodically refresh the frequently used entry
@@ -1214,63 +1094,17 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
entry, ok := linkRes.entries.entry(i)
if !ok {
- t.Fatalf("linkRes.entries.entry(%d) not found", i)
- }
-
- _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
- if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" {
- t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- clock.Advance(typicalLatency)
- select {
- case <-ch:
- default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
}
// An entry should have been removed, as per the LRU eviction strategy
removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1)
if !ok {
- t.Fatalf("linkRes.entries.entry(%d) not found", i-neighborCacheSize+1)
- }
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestRemoved,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
- },
- },
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
+ t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize+1)
}
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+
+ if err := addReachableEntryWithRemoved(&nudDisp, clock, linkRes, entry, []NeighborEntry{removedEntry}); err != nil {
+ t.Fatalf("addReachableEntryWithRemoved(...) = %s", err)
}
}
@@ -1282,23 +1116,27 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
Addr: frequentlyUsedEntry.Addr,
LinkAddr: frequentlyUsedEntry.LinkAddr,
State: Reachable,
+ // Can be inferred since the frequently used entry is the first to
+ // be created and transitioned to Reachable.
+ UpdatedAtNanos: startedAt + typicalLatency.Nanoseconds(),
},
}
for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ {
entry, ok := linkRes.entries.entry(i)
if !ok {
- t.Fatalf("linkRes.entries.entry(%d) not found", i)
- }
- wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
}
- wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds()
+ wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos,
+ })
}
- if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" {
t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
@@ -1350,17 +1188,18 @@ func TestNeighborCacheConcurrent(t *testing.T) {
for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ {
entry, ok := linkRes.entries.entry(i)
if !ok {
- t.Errorf("linkRes.entries.entry(%d) not found", i)
+ t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i)
}
- wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- }
- wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds()
+ wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos,
+ })
}
- if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" {
t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
}
@@ -1372,44 +1211,12 @@ func TestNeighborCacheReplace(t *testing.T) {
clock := faketime.NewManualClock()
linkRes := newTestNeighborResolver(&nudDisp, config, clock)
- // Add an entry
entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("linkRes.entries.entry(0) not found")
- }
-
- _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
- if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" {
- t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- clock.Advance(typicalLatency)
- select {
- case <-ch:
- default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
}
-
- // Verify the entry exists
- {
- e, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
- if err != nil {
- t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err)
- }
- if t.Failed() {
- t.FailNow()
- }
- want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- }
- if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
- t.Errorf("linkRes.neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
- }
+ if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
+ t.Fatalf("addReachableEntry(...) = %s", err)
}
// Notify of a link address change
@@ -1417,7 +1224,7 @@ func TestNeighborCacheReplace(t *testing.T) {
{
entry, ok := linkRes.entries.entry(1)
if !ok {
- t.Fatal("linkRes.entries.entry(1) not found")
+ t.Fatal("got linkRes.entries.entry(1) = _, false, want = true")
}
updatedLinkAddr = entry.LinkAddr
}
@@ -1437,29 +1244,31 @@ func TestNeighborCacheReplace(t *testing.T) {
t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: updatedLinkAddr,
- State: Delay,
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
+ UpdatedAtNanos: clock.NowNanoseconds(),
}
- if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ if diff := cmp.Diff(want, e); diff != "" {
t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
- clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
+ clock.Advance(config.DelayFirstProbeTime + typicalLatency)
+
// Verify that the neighbor is now reachable.
{
e, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
- clock.Advance(typicalLatency)
if err != nil {
t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: updatedLinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: clock.NowNanoseconds(),
}
- if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ if diff := cmp.Diff(want, e); diff != "" {
t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
}
@@ -1479,25 +1288,12 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("linkRes.entries.entry(0) not found")
+ t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
}
// First, sanity check that resolution is working
- {
- _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
- if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" {
- t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- clock.Advance(typicalLatency)
- select {
- case <-ch:
- default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
- }
+ if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
+ t.Fatalf("addReachableEntry(...) = %s", err)
}
got, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
@@ -1505,11 +1301,12 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: clock.NowNanoseconds(),
}
- if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" {
+ if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
@@ -1524,14 +1321,14 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
select {
case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
}
}
@@ -1555,7 +1352,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("linkRes.entries.entry(0) not found")
+ t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
}
_, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
@@ -1564,7 +1361,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
@@ -1572,7 +1369,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
select {
case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
}
}
@@ -1580,14 +1377,15 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
// failing to perform address resolution.
func TestNeighborCacheRetryResolution(t *testing.T) {
config := DefaultNUDConfigurations()
+ nudDisp := testNUDDispatcher{}
clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(nil, config, clock)
+ linkRes := newTestNeighborResolver(&nudDisp, config, clock)
// Simulate a faulty link.
linkRes.dropReplies = true
entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("linkRes.entries.entry(0) not found")
+ t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
}
// Perform address resolution with a faulty link, which will fail.
@@ -1598,27 +1396,75 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ }
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Incomplete,
+ UpdatedAtNanos: clock.NowNanoseconds(),
+ },
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(wantEvents, nudDisp.mu.events)
+ nudDisp.mu.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
}
+
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
select {
case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
}
- }
- wantEntries := []NeighborEntry{
{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Unreachable,
- },
- }
- if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" {
- t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Unreachable,
+ UpdatedAtNanos: clock.NowNanoseconds(),
+ },
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(wantEvents, nudDisp.mu.events)
+ nudDisp.mu.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
+ }
+
+ {
+ wantEntries := []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Unreachable,
+ UpdatedAtNanos: clock.NowNanoseconds(),
+ },
+ }
+ if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" {
+ t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+ }
}
// Retry address resolution with a working link.
@@ -1635,28 +1481,74 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
if incompleteEntry.State != Incomplete {
t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete)
}
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Incomplete,
+ UpdatedAtNanos: clock.NowNanoseconds(),
+ },
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(wantEvents, nudDisp.mu.events)
+ nudDisp.mu.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
+ }
+
clock.Advance(typicalLatency)
select {
case <-ch:
- if !ok {
- t.Fatal("expected successful address resolution")
+ default:
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
+ }
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: clock.NowNanoseconds(),
+ },
+ },
}
- reachableEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
- if err != nil {
- t.Fatalf("linkRes.neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err)
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(wantEvents, nudDisp.mu.events)
+ nudDisp.mu.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
- if reachableEntry.Addr != entry.Addr {
- t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr)
+ }
+
+ {
+ gotEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
+ if err != nil {
+ t.Fatalf("linkRes.neigh.entry(%s, '', _): %s", entry.Addr, err)
}
- if reachableEntry.LinkAddr != entry.LinkAddr {
- t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr)
+
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAtNanos: clock.NowNanoseconds(),
}
- if reachableEntry.State != Reachable {
- t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String())
+ if diff := cmp.Diff(gotEntry, wantEntry); diff != "" {
+ t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff)
}
- default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
}
@@ -1674,7 +1566,7 @@ func BenchmarkCacheClear(b *testing.B) {
for i := 0; i < cacheSize; i++ {
entry, ok := linkRes.entries.entry(i)
if !ok {
- b.Fatalf("linkRes.entries.entry(%d) not found", i)
+ b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
}
_, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
@@ -1683,13 +1575,13 @@ func BenchmarkCacheClear(b *testing.B) {
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- b.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ b.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
select {
case <-ch:
default:
- b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
}
}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index baae7dfe1..bb2b2d705 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -18,13 +18,11 @@ import (
"fmt"
"math"
"math/rand"
- "strings"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -52,23 +50,6 @@ func runImmediatelyScheduledJobs(clock *faketime.ManualClock) {
clock.Advance(immediateDuration)
}
-// eventDiffOpts are the options passed to cmp.Diff to compare entry events.
-// The UpdatedAtNanos field is ignored due to a lack of a deterministic method
-// to predict the time that an event will be dispatched.
-func eventDiffOpts() []cmp.Option {
- return []cmp.Option{
- cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"),
- }
-}
-
-// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to
-// sort slices of events for cases where ordering must be ignored.
-func eventDiffOptsWithSort() []cmp.Option {
- return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
- return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0
- }))
-}
-
// The following unit tests exercise every state transition and verify its
// behavior with RFC 4681 and RFC 7048.
//
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index f9323d545..62f7c880e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -725,12 +725,12 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.mu.RUnlock()
n.stats.DisabledRx.Packets.Increment()
- n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size()))
return
}
n.stats.Rx.Packets.Increment()
- n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size()))
networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
@@ -881,7 +881,7 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo
// ICMPv4 only guarantees that 8 bytes of the transport protocol will
// be present in the payload. We know that the ports are within the
// first 8 bytes for all known transport protocols.
- transHeader, ok := pkt.Data.PullUp(8)
+ transHeader, ok := pkt.Data().PullUp(8)
if !ok {
return
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 4f013b212..8f288675d 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -59,7 +59,7 @@ type PacketBuffer struct {
// PacketBuffers.
PacketBufferEntry
- // Data holds the payload of the packet.
+ // data holds the payload of the packet.
//
// For inbound packets, Data is initially the whole packet. Then gets moved to
// headers via PacketHeader.Consume, when the packet is being parsed.
@@ -69,7 +69,7 @@ type PacketBuffer struct {
//
// The bytes backing Data are immutable, a.k.a. users shouldn't write to its
// backing storage.
- Data buffer.VectorisedView
+ data buffer.VectorisedView
// headers stores metadata about each header.
headers [numHeaderType]headerInfo
@@ -127,7 +127,7 @@ type PacketBuffer struct {
// NewPacketBuffer creates a new PacketBuffer with opts.
func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer {
pk := &PacketBuffer{
- Data: opts.Data,
+ data: opts.Data,
}
if opts.ReserveHeaderBytes != 0 {
pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes)
@@ -184,13 +184,18 @@ func (pk *PacketBuffer) HeaderSize() int {
// Size returns the size of packet in bytes.
func (pk *PacketBuffer) Size() int {
- return pk.HeaderSize() + pk.Data.Size()
+ return pk.HeaderSize() + pk.data.Size()
}
// MemSize returns the estimation size of the pk in memory, including backing
// buffer data.
func (pk *PacketBuffer) MemSize() int {
- return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize
+ return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize
+}
+
+// Data returns the handle to data portion of pk.
+func (pk *PacketBuffer) Data() PacketData {
+ return PacketData{pk: pk}
}
// Views returns the underlying storage of the whole packet.
@@ -204,7 +209,7 @@ func (pk *PacketBuffer) Views() []buffer.View {
}
}
- dataViews := pk.Data.Views()
+ dataViews := pk.data.Views()
var vs []buffer.View
if useHeader {
@@ -242,11 +247,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
if h.buf != nil {
panic(fmt.Sprintf("consume must not be called twice: type %s", typ))
}
- v, ok := pk.Data.PullUp(size)
+ v, ok := pk.data.PullUp(size)
if !ok {
return
}
- pk.Data.TrimFront(size)
+ pk.data.TrimFront(size)
h.buf = v
return h.buf, true
}
@@ -258,7 +263,7 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
func (pk *PacketBuffer) Clone() *PacketBuffer {
return &PacketBuffer{
PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
+ data: pk.data.Clone(nil),
headers: pk.headers,
header: pk.header,
Hash: pk.Hash,
@@ -339,13 +344,234 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) {
return h.pk.consume(h.typ, size)
}
+// PacketData represents the data portion of a PacketBuffer.
+type PacketData struct {
+ pk *PacketBuffer
+}
+
+// PullUp returns a contiguous view of size bytes from the beginning of d.
+// Callers should not write to or keep the view for later use.
+func (d PacketData) PullUp(size int) (buffer.View, bool) {
+ return d.pk.data.PullUp(size)
+}
+
+// TrimFront removes count from the beginning of d. It panics if count >
+// d.Size().
+func (d PacketData) TrimFront(count int) {
+ d.pk.data.TrimFront(count)
+}
+
+// CapLength reduces d to at most length bytes.
+func (d PacketData) CapLength(length int) {
+ d.pk.data.CapLength(length)
+}
+
+// Views returns the underlying storage of d in a slice of Views. Caller should
+// not modify the returned slice.
+func (d PacketData) Views() []buffer.View {
+ return d.pk.data.Views()
+}
+
+// AppendView appends v into d, taking the ownership of v.
+func (d PacketData) AppendView(v buffer.View) {
+ d.pk.data.AppendView(v)
+}
+
+// ReadFromData moves at most count bytes from the beginning of srcData to the
+// end of d and returns the number of bytes moved.
+func (d PacketData) ReadFromData(srcData PacketData, count int) int {
+ return srcData.pk.data.ReadToVV(&d.pk.data, count)
+}
+
+// ReadFromVV moves at most count bytes from the beginning of srcVV to the end
+// of d and returns the number of bytes moved.
+func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int {
+ return srcVV.ReadToVV(&d.pk.data, count)
+}
+
+// Size returns the number of bytes in the data payload of the packet.
+func (d PacketData) Size() int {
+ return d.pk.data.Size()
+}
+
+// AsRange returns a Range representing the current data payload of the packet.
+func (d PacketData) AsRange() Range {
+ return Range{
+ pk: d.pk,
+ offset: d.pk.HeaderSize(),
+ length: d.Size(),
+ }
+}
+
+// ExtractVV returns a VectorisedView of d. This method has the semantic to
+// destruct the underlying packet, hence the packet cannot be used again.
+//
+// This method exists for compatibility between PacketBuffer and VectorisedView.
+// It may be removed later and should be used with care.
+func (d PacketData) ExtractVV() buffer.VectorisedView {
+ return d.pk.data
+}
+
+// Replace replaces the data portion of the packet with vv, taking the ownership
+// of vv.
+//
+// This method exists for compatibility between PacketBuffer and VectorisedView.
+// It may be removed later and should be used with care.
+func (d PacketData) Replace(vv buffer.VectorisedView) {
+ d.pk.data = vv
+}
+
+// Range represents a contiguous subportion of a PacketBuffer.
+type Range struct {
+ pk *PacketBuffer
+ offset int
+ length int
+}
+
+// Size returns the number of bytes in r.
+func (r Range) Size() int {
+ return r.length
+}
+
+// SubRange returns a new Range starting at off bytes of r. It returns an empty
+// range if off is out-of-bounds.
+func (r Range) SubRange(off int) Range {
+ if off > r.length {
+ return Range{pk: r.pk}
+ }
+ return Range{
+ pk: r.pk,
+ offset: r.offset + off,
+ length: r.length - off,
+ }
+}
+
+// Capped returns a new Range with the same starting point of r and length
+// capped at max.
+func (r Range) Capped(max int) Range {
+ if r.length <= max {
+ return r
+ }
+ return Range{
+ pk: r.pk,
+ offset: r.offset,
+ length: max,
+ }
+}
+
+// AsView returns the backing storage of r if possible. It will allocate a new
+// View if r spans multiple pieces internally. Caller should not write to the
+// returned View in any way.
+func (r Range) AsView() buffer.View {
+ var allocated bool
+ var v buffer.View
+ r.iterate(func(b []byte) {
+ if v == nil {
+ // v has not been assigned, allowing first view to be returned.
+ v = b
+ } else {
+ // v has been assigned. This range spans more than a view, a new view
+ // needs to be allocated.
+ if !allocated {
+ allocated = true
+ all := make([]byte, 0, r.length)
+ all = append(all, v...)
+ v = all
+ }
+ v = append(v, b...)
+ }
+ })
+ return v
+}
+
+// ToOwnedView returns a owned copy of data in r.
+func (r Range) ToOwnedView() buffer.View {
+ if r.length == 0 {
+ return nil
+ }
+ all := make([]byte, 0, r.length)
+ r.iterate(func(b []byte) {
+ all = append(all, b...)
+ })
+ return all
+}
+
+// Checksum calculates the RFC 1071 checksum for the underlying bytes of r.
+func (r Range) Checksum() uint16 {
+ var c header.Checksumer
+ r.iterate(c.Add)
+ return c.Checksum()
+}
+
+// iterate calls fn for each piece in r. fn is always called with a non-empty
+// slice.
+func (r Range) iterate(fn func([]byte)) {
+ w := window{
+ offset: r.offset,
+ length: r.length,
+ }
+ // Header portion.
+ for i := range r.pk.headers {
+ if b := w.process(r.pk.headers[i].buf); len(b) > 0 {
+ fn(b)
+ }
+ if w.isDone() {
+ break
+ }
+ }
+ // Data portion.
+ if !w.isDone() {
+ for _, v := range r.pk.data.Views() {
+ if b := w.process(v); len(b) > 0 {
+ fn(b)
+ }
+ if w.isDone() {
+ break
+ }
+ }
+ }
+}
+
+// window represents contiguous region of byte stream. User would call process()
+// to input bytes, and obtain a subslice that is inside the window.
+type window struct {
+ offset int
+ length int
+}
+
+// isDone returns true if the window has passed and further process() calls will
+// always return an empty slice. This can be used to end processing early.
+func (w *window) isDone() bool {
+ return w.length == 0
+}
+
+// process feeds b in and returns a subslice that is inside the window. The
+// returned slice will be a subslice of b, and it does not keep b after method
+// returns. This method may return an empty slice if nothing in b is inside the
+// window.
+func (w *window) process(b []byte) (inWindow []byte) {
+ if w.offset >= len(b) {
+ w.offset -= len(b)
+ return nil
+ }
+ if w.offset > 0 {
+ b = b[w.offset:]
+ w.offset = 0
+ }
+ if w.length < len(b) {
+ b = b[:w.length]
+ }
+ w.length -= len(b)
+ return b
+}
+
// PayloadSince returns packet payload starting from and including a particular
// header.
//
// The returned View is owned by the caller - its backing buffer is separate
// from the packet header's underlying packet buffer.
func PayloadSince(h PacketHeader) buffer.View {
- size := h.pk.Data.Size()
+ size := h.pk.data.Size()
for _, hinfo := range h.pk.headers[h.typ:] {
size += len(hinfo.buf)
}
@@ -356,7 +582,7 @@ func PayloadSince(h PacketHeader) buffer.View {
v = append(v, hinfo.buf...)
}
- for _, view := range h.pk.Data.Views() {
+ for _, view := range h.pk.data.Views() {
v = append(v, view...)
}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index c6fa8da5f..6728370c3 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -15,9 +15,11 @@ package stack
import (
"bytes"
+ "fmt"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
func TestPacketHeaderPush(t *testing.T) {
@@ -110,7 +112,7 @@ func TestPacketHeaderPush(t *testing.T) {
if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data)
+ checkData(t, pk, test.data)
checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...),
concatViews(test.link, test.network, test.transport, test.data))
// Check the after values for each header.
@@ -204,7 +206,7 @@ func TestPacketHeaderConsume(t *testing.T) {
transport = test.data[test.link+test.network:][:test.transport]
payload = test.data[allHdrSize:]
)
- checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload)
+ checkData(t, pk, payload)
checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data)
// Check the after values for each header.
checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link)
@@ -340,6 +342,158 @@ func TestPacketHeaderConsumeThenPushPanics(t *testing.T) {
}
}
+func TestPacketBufferData(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ makePkt func(*testing.T) *PacketBuffer
+ data string
+ }{
+ {
+ name: "inbound packet",
+ makePkt: func(*testing.T) *PacketBuffer {
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ Data: vv("aabbbbccccccDATA"),
+ })
+ pkt.LinkHeader().Consume(2)
+ pkt.NetworkHeader().Consume(4)
+ pkt.TransportHeader().Consume(6)
+ return pkt
+ },
+ data: "DATA",
+ },
+ {
+ name: "outbound packet",
+ makePkt: func(*testing.T) *PacketBuffer {
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: 12,
+ Data: vv("DATA"),
+ })
+ copy(pkt.TransportHeader().Push(6), []byte("cccccc"))
+ copy(pkt.NetworkHeader().Push(4), []byte("bbbb"))
+ copy(pkt.LinkHeader().Push(2), []byte("aa"))
+ return pkt
+ },
+ data: "DATA",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ // PullUp
+ for _, n := range []int{1, len(tc.data)} {
+ t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ v, ok := pkt.Data().PullUp(n)
+ wantV := []byte(tc.data)[:n]
+ if !ok || !bytes.Equal(v, wantV) {
+ t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV)
+ }
+ })
+ }
+ t.Run("PullUpOutOfBounds", func(t *testing.T) {
+ n := len(tc.data) + 1
+ pkt := tc.makePkt(t)
+ v, ok := pkt.Data().PullUp(n)
+ if ok || v != nil {
+ t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok)
+ }
+ })
+
+ // TrimFront
+ for _, n := range []int{1, len(tc.data)} {
+ t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ pkt.Data().TrimFront(n)
+
+ checkData(t, pkt, []byte(tc.data)[n:])
+ })
+ }
+
+ // CapLength
+ for _, n := range []int{0, 1, len(tc.data)} {
+ t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ pkt.Data().CapLength(n)
+
+ want := []byte(tc.data)
+ if n < len(want) {
+ want = want[:n]
+ }
+ checkData(t, pkt, want)
+ })
+ }
+
+ // Views
+ t.Run("Views", func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ checkData(t, pkt, []byte(tc.data))
+ })
+
+ // AppendView
+ t.Run("AppendView", func(t *testing.T) {
+ s := "APPEND"
+
+ pkt := tc.makePkt(t)
+ pkt.Data().AppendView(buffer.View(s))
+
+ checkData(t, pkt, []byte(tc.data+s))
+ })
+
+ // ReadFromData/VV
+ for _, n := range []int{0, 1, 2, 7, 10, 14, 20} {
+ t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) {
+ s := "TO READ"
+ otherPkt := NewPacketBuffer(PacketBufferOptions{
+ Data: vv(s, s),
+ })
+ s += s
+
+ pkt := tc.makePkt(t)
+ pkt.Data().ReadFromData(otherPkt.Data(), n)
+
+ if n < len(s) {
+ s = s[:n]
+ }
+ checkData(t, pkt, []byte(tc.data+s))
+ })
+ t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) {
+ s := "TO READ"
+ srcVV := vv(s, s)
+ s += s
+
+ pkt := tc.makePkt(t)
+ pkt.Data().ReadFromVV(&srcVV, n)
+
+ if n < len(s) {
+ s = s[:n]
+ }
+ checkData(t, pkt, []byte(tc.data+s))
+ })
+ }
+
+ // ExtractVV
+ t.Run("ExtractVV", func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ extractedVV := pkt.Data().ExtractVV()
+
+ got := extractedVV.ToOwnedView()
+ want := []byte(tc.data)
+ if !bytes.Equal(got, want) {
+ t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want)
+ }
+ })
+
+ // Replace
+ t.Run("Replace", func(t *testing.T) {
+ s := "REPLACED"
+
+ pkt := tc.makePkt(t)
+ pkt.Data().Replace(vv(s))
+
+ checkData(t, pkt, []byte(s))
+ })
+ })
+ }
+}
+
func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
t.Helper()
reserved := opts.ReserveHeaderBytes
@@ -356,7 +510,7 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO
if got, want := pk.Size(), len(data); got != want {
t.Errorf("Initial pk.Size() = %d, want %d", got, want)
}
- checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data)
+ checkData(t, pk, data)
checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)
// Check the initial values for each header.
checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil)
@@ -383,6 +537,70 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) {
}
}
+func checkData(t *testing.T, pkt *PacketBuffer, want []byte) {
+ t.Helper()
+ if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) {
+ t.Errorf("pkt.Data().Views() = %x, want %x", got, want)
+ }
+ if got := pkt.Data().Size(); got != len(want) {
+ t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want))
+ }
+
+ t.Run("AsRange", func(t *testing.T) {
+ // Full range
+ checkRange(t, pkt.Data().AsRange(), want)
+
+ // SubRange
+ for _, off := range []int{0, 1, len(want), len(want) + 1} {
+ t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) {
+ // Empty when off is greater than the size of range.
+ var sub []byte
+ if off < len(want) {
+ sub = want[off:]
+ }
+ checkRange(t, pkt.Data().AsRange().SubRange(off), sub)
+ })
+ }
+
+ // Capped
+ for _, n := range []int{0, 1, len(want), len(want) + 1} {
+ t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) {
+ sub := want
+ if n < len(sub) {
+ sub = sub[:n]
+ }
+ checkRange(t, pkt.Data().AsRange().Capped(n), sub)
+ })
+ }
+ })
+}
+
+func checkRange(t *testing.T, r Range, data []byte) {
+ if got, want := r.Size(), len(data); got != want {
+ t.Errorf("r.Size() = %d, want %d", got, want)
+ }
+ if got := r.AsView(); !bytes.Equal(got, data) {
+ t.Errorf("r.AsView() = %x, want %x", got, data)
+ }
+ if got := r.ToOwnedView(); !bytes.Equal(got, data) {
+ t.Errorf("r.ToOwnedView() = %x, want %x", got, data)
+ }
+ if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want {
+ t.Errorf("r.Checksum() = %x, want %x", got, want)
+ }
+}
+
+func vv(pieces ...string) buffer.VectorisedView {
+ var views []buffer.View
+ var size int
+ for _, p := range pieces {
+ v := buffer.View([]byte(p))
+ size += len(v)
+ views = append(views, v)
+ }
+ return buffer.NewVectorisedView(size, views)
+}
+
func makeView(size int) buffer.View {
b := byte(size)
return bytes.Repeat([]byte{b}, size)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 43e9e4beb..85f0f471a 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -852,18 +852,46 @@ type InjectableLinkEndpoint interface {
InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error
}
-// DADResult is the result of a duplicate address detection process.
-type DADResult struct {
- // Resolved is true when DAD completed without detecting a duplicate address
- // on the link.
- //
- // Ignored when Err is non-nil.
- Resolved bool
+// DADResult is a marker interface for the result of a duplicate address
+// detection process.
+type DADResult interface {
+ isDADResult()
+}
+
+var _ DADResult = (*DADSucceeded)(nil)
+
+// DADSucceeded indicates DAD completed without finding any duplicate addresses.
+type DADSucceeded struct{}
- // Err is an error encountered while performing DAD.
+func (*DADSucceeded) isDADResult() {}
+
+var _ DADResult = (*DADError)(nil)
+
+// DADError indicates DAD hit an error.
+type DADError struct {
Err tcpip.Error
}
+func (*DADError) isDADResult() {}
+
+var _ DADResult = (*DADAborted)(nil)
+
+// DADAborted indicates DAD was aborted.
+type DADAborted struct{}
+
+func (*DADAborted) isDADResult() {}
+
+var _ DADResult = (*DADDupAddrDetected)(nil)
+
+// DADDupAddrDetected indicates DAD detected a duplicate address.
+type DADDupAddrDetected struct {
+ // HolderLinkAddress is the link address of the node that holds the duplicate
+ // address.
+ HolderLinkAddress tcpip.LinkAddress
+}
+
+func (*DADDupAddrDetected) isDADResult() {}
+
// DADCompletionHandler is a handler for DAD completion.
type DADCompletionHandler func(DADResult)
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index de94ddfda..53370c354 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -813,6 +813,18 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
return forwardingProtocol.Forwarding()
}
+// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
+// both IPv4 and IPv6.
+func (s *Stack) PortRange() (uint16, uint16) {
+ return s.PortManager.PortRange()
+}
+
+// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range
+// (inclusive).
+func (s *Stack) SetPortRange(start uint16, end uint16) tcpip.Error {
+ return s.PortManager.SetPortRange(start, end)
+}
+
// SetRouteTable assigns the route table to be used by this stack. It
// specifies which NIC to use for given destination address ranges.
//
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 8e39e828c..880219007 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -137,11 +137,11 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ nb, ok := pkt.Data().PullUp(fakeNetHeaderLen)
if !ok {
return
}
- pkt.Data.TrimFront(fakeNetHeaderLen)
+ pkt.Data().TrimFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
@@ -2605,7 +2605,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" {
+ if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
@@ -3289,7 +3289,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" {
+ if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
@@ -4294,7 +4294,7 @@ func TestWritePacketToRemote(t *testing.T) {
if pkt.Route.RemoteLinkAddress != linkAddr2 {
t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2)
}
- if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
+ if diff := cmp.Diff(pkt.Pkt.Data().AsRange().ToOwnedView(), buffer.View(test.payload)); diff != "" {
t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
}
})
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index e799f9290..e188efccb 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -359,7 +359,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[0]
}
- if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent {
+ if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent {
return mpep.endpoints[len(mpep.endpoints)-1]
}
@@ -410,7 +410,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
- if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
return &tcpip.ErrPortInUse{}
}
}
@@ -429,7 +429,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
- if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
return &tcpip.ErrPortInUse{}
}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 01a4389e3..87ea09a5e 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1258,44 +1258,38 @@ func (m *MultiCounterStat) IncrementBy(v uint64) {
type ICMPv4PacketStats struct {
// LINT.IfChange(ICMPv4PacketStats)
- // Echo is the total number of ICMPv4 echo packets counted.
- Echo *StatCounter
+ // EchoRequest is the number of ICMPv4 echo packets counted.
+ EchoRequest *StatCounter
- // EchoReply is the total number of ICMPv4 echo reply packets counted.
+ // EchoReply is the number of ICMPv4 echo reply packets counted.
EchoReply *StatCounter
- // DstUnreachable is the total number of ICMPv4 destination unreachable
- // packets counted.
+ // DstUnreachable is the number of ICMPv4 destination unreachable packets
+ // counted.
DstUnreachable *StatCounter
- // SrcQuench is the total number of ICMPv4 source quench packets
- // counted.
+ // SrcQuench is the number of ICMPv4 source quench packets counted.
SrcQuench *StatCounter
- // Redirect is the total number of ICMPv4 redirect packets counted.
+ // Redirect is the number of ICMPv4 redirect packets counted.
Redirect *StatCounter
- // TimeExceeded is the total number of ICMPv4 time exceeded packets
- // counted.
+ // TimeExceeded is the number of ICMPv4 time exceeded packets counted.
TimeExceeded *StatCounter
- // ParamProblem is the total number of ICMPv4 parameter problem packets
- // counted.
+ // ParamProblem is the number of ICMPv4 parameter problem packets counted.
ParamProblem *StatCounter
- // Timestamp is the total number of ICMPv4 timestamp packets counted.
+ // Timestamp is the number of ICMPv4 timestamp packets counted.
Timestamp *StatCounter
- // TimestampReply is the total number of ICMPv4 timestamp reply packets
- // counted.
+ // TimestampReply is the number of ICMPv4 timestamp reply packets counted.
TimestampReply *StatCounter
- // InfoRequest is the total number of ICMPv4 information request
- // packets counted.
+ // InfoRequest is the number of ICMPv4 information request packets counted.
InfoRequest *StatCounter
- // InfoReply is the total number of ICMPv4 information reply packets
- // counted.
+ // InfoReply is the number of ICMPv4 information reply packets counted.
InfoReply *StatCounter
// LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4PacketStats)
@@ -1307,12 +1301,11 @@ type ICMPv4SentPacketStats struct {
ICMPv4PacketStats
- // Dropped is the total number of ICMPv4 packets dropped due to link
- // layer errors.
+ // Dropped is the number of ICMPv4 packets dropped due to link layer errors.
Dropped *StatCounter
- // RateLimited is the total number of ICMPv4 packets dropped due to
- // rate limit being exceeded.
+ // RateLimited is the number of ICMPv4 packets dropped due to rate limit being
+ // exceeded.
RateLimited *StatCounter
// LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4SentPacketStats)
@@ -1324,7 +1317,7 @@ type ICMPv4ReceivedPacketStats struct {
ICMPv4PacketStats
- // Invalid is the total number of invalid ICMPv4 packets received.
+ // Invalid is the number of invalid ICMPv4 packets received.
Invalid *StatCounter
// LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4ReceivedPacketStats)
@@ -1347,59 +1340,50 @@ type ICMPv4Stats struct {
type ICMPv6PacketStats struct {
// LINT.IfChange(ICMPv6PacketStats)
- // EchoRequest is the total number of ICMPv6 echo request packets
- // counted.
+ // EchoRequest is the number of ICMPv6 echo request packets counted.
EchoRequest *StatCounter
- // EchoReply is the total number of ICMPv6 echo reply packets counted.
+ // EchoReply is the number of ICMPv6 echo reply packets counted.
EchoReply *StatCounter
- // DstUnreachable is the total number of ICMPv6 destination unreachable
- // packets counted.
+ // DstUnreachable is the number of ICMPv6 destination unreachable packets
+ // counted.
DstUnreachable *StatCounter
- // PacketTooBig is the total number of ICMPv6 packet too big packets
- // counted.
+ // PacketTooBig is the number of ICMPv6 packet too big packets counted.
PacketTooBig *StatCounter
- // TimeExceeded is the total number of ICMPv6 time exceeded packets
- // counted.
+ // TimeExceeded is the number of ICMPv6 time exceeded packets counted.
TimeExceeded *StatCounter
- // ParamProblem is the total number of ICMPv6 parameter problem packets
- // counted.
+ // ParamProblem is the number of ICMPv6 parameter problem packets counted.
ParamProblem *StatCounter
- // RouterSolicit is the total number of ICMPv6 router solicit packets
- // counted.
+ // RouterSolicit is the number of ICMPv6 router solicit packets counted.
RouterSolicit *StatCounter
- // RouterAdvert is the total number of ICMPv6 router advert packets
- // counted.
+ // RouterAdvert is the number of ICMPv6 router advert packets counted.
RouterAdvert *StatCounter
- // NeighborSolicit is the total number of ICMPv6 neighbor solicit
- // packets counted.
+ // NeighborSolicit is the number of ICMPv6 neighbor solicit packets counted.
NeighborSolicit *StatCounter
- // NeighborAdvert is the total number of ICMPv6 neighbor advert packets
- // counted.
+ // NeighborAdvert is the number of ICMPv6 neighbor advert packets counted.
NeighborAdvert *StatCounter
- // RedirectMsg is the total number of ICMPv6 redirect message packets
- // counted.
+ // RedirectMsg is the number of ICMPv6 redirect message packets counted.
RedirectMsg *StatCounter
- // MulticastListenerQuery is the total number of Multicast Listener Query
- // messages counted.
+ // MulticastListenerQuery is the number of Multicast Listener Query messages
+ // counted.
MulticastListenerQuery *StatCounter
- // MulticastListenerReport is the total number of Multicast Listener Report
- // messages counted.
+ // MulticastListenerReport is the number of Multicast Listener Report messages
+ // counted.
MulticastListenerReport *StatCounter
- // MulticastListenerDone is the total number of Multicast Listener Done
- // messages counted.
+ // MulticastListenerDone is the number of Multicast Listener Done messages
+ // counted.
MulticastListenerDone *StatCounter
// LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6PacketStats)
@@ -1411,12 +1395,11 @@ type ICMPv6SentPacketStats struct {
ICMPv6PacketStats
- // Dropped is the total number of ICMPv6 packets dropped due to link
- // layer errors.
+ // Dropped is the number of ICMPv6 packets dropped due to link layer errors.
Dropped *StatCounter
- // RateLimited is the total number of ICMPv6 packets dropped due to
- // rate limit being exceeded.
+ // RateLimited is the number of ICMPv6 packets dropped due to rate limit being
+ // exceeded.
RateLimited *StatCounter
// LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6SentPacketStats)
@@ -1428,15 +1411,15 @@ type ICMPv6ReceivedPacketStats struct {
ICMPv6PacketStats
- // Unrecognized is the total number of ICMPv6 packets received that the
- // transport layer does not know how to parse.
+ // Unrecognized is the number of ICMPv6 packets received that the transport
+ // layer does not know how to parse.
Unrecognized *StatCounter
- // Invalid is the total number of invalid ICMPv6 packets received.
+ // Invalid is the number of invalid ICMPv6 packets received.
Invalid *StatCounter
- // RouterOnlyPacketsDroppedByHost is the total number of ICMPv6 packets
- // dropped due to being router-specific packets.
+ // RouterOnlyPacketsDroppedByHost is the number of ICMPv6 packets dropped due
+ // to being router-specific packets.
RouterOnlyPacketsDroppedByHost *StatCounter
// LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6ReceivedPacketStats)
@@ -1468,18 +1451,18 @@ type ICMPStats struct {
type IGMPPacketStats struct {
// LINT.IfChange(IGMPPacketStats)
- // MembershipQuery is the total number of Membership Query messages counted.
+ // MembershipQuery is the number of Membership Query messages counted.
MembershipQuery *StatCounter
- // V1MembershipReport is the total number of Version 1 Membership Report
- // messages counted.
+ // V1MembershipReport is the number of Version 1 Membership Report messages
+ // counted.
V1MembershipReport *StatCounter
- // V2MembershipReport is the total number of Version 2 Membership Report
- // messages counted.
+ // V2MembershipReport is the number of Version 2 Membership Report messages
+ // counted.
V2MembershipReport *StatCounter
- // LeaveGroup is the total number of Leave Group messages counted.
+ // LeaveGroup is the number of Leave Group messages counted.
LeaveGroup *StatCounter
// LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPPacketStats)
@@ -1491,7 +1474,7 @@ type IGMPSentPacketStats struct {
IGMPPacketStats
- // Dropped is the total number of IGMP packets dropped.
+ // Dropped is the number of IGMP packets dropped.
Dropped *StatCounter
// LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPSentPacketStats)
@@ -1503,15 +1486,14 @@ type IGMPReceivedPacketStats struct {
IGMPPacketStats
- // Invalid is the total number of invalid IGMP packets received.
+ // Invalid is the number of invalid IGMP packets received.
Invalid *StatCounter
- // ChecksumErrors is the total number of IGMP packets dropped due to bad
- // checksums.
+ // ChecksumErrors is the number of IGMP packets dropped due to bad checksums.
ChecksumErrors *StatCounter
- // Unrecognized is the total number of unrecognized messages counted, these
- // are silently ignored for forward-compatibilty.
+ // Unrecognized is the number of unrecognized messages counted, these are
+ // silently ignored for forward-compatibilty.
Unrecognized *StatCounter
// LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPReceivedPacketStats)
@@ -1534,51 +1516,50 @@ type IGMPStats struct {
type IPStats struct {
// LINT.IfChange(IPStats)
- // PacketsReceived is the total number of IP packets received from the
- // link layer.
+ // PacketsReceived is the number of IP packets received from the link layer.
PacketsReceived *StatCounter
- // DisabledPacketsReceived is the total number of IP packets received from the
- // link layer when the IP layer is disabled.
+ // DisabledPacketsReceived is the number of IP packets received from the link
+ // layer when the IP layer is disabled.
DisabledPacketsReceived *StatCounter
- // InvalidDestinationAddressesReceived is the total number of IP packets
- // received with an unknown or invalid destination address.
+ // InvalidDestinationAddressesReceived is the number of IP packets received
+ // with an unknown or invalid destination address.
InvalidDestinationAddressesReceived *StatCounter
- // InvalidSourceAddressesReceived is the total number of IP packets received
- // with a source address that should never have been received on the wire.
+ // InvalidSourceAddressesReceived is the number of IP packets received with a
+ // source address that should never have been received on the wire.
InvalidSourceAddressesReceived *StatCounter
- // PacketsDelivered is the total number of incoming IP packets that
- // are successfully delivered to the transport layer.
+ // PacketsDelivered is the number of incoming IP packets that are successfully
+ // delivered to the transport layer.
PacketsDelivered *StatCounter
- // PacketsSent is the total number of IP packets sent via WritePacket.
+ // PacketsSent is the number of IP packets sent via WritePacket.
PacketsSent *StatCounter
- // OutgoingPacketErrors is the total number of IP packets which failed
- // to write to a link-layer endpoint.
+ // OutgoingPacketErrors is the number of IP packets which failed to write to a
+ // link-layer endpoint.
OutgoingPacketErrors *StatCounter
- // MalformedPacketsReceived is the total number of IP Packets that were
- // dropped due to the IP packet header failing validation checks.
+ // MalformedPacketsReceived is the number of IP Packets that were dropped due
+ // to the IP packet header failing validation checks.
MalformedPacketsReceived *StatCounter
- // MalformedFragmentsReceived is the total number of IP Fragments that were
- // dropped due to the fragment failing validation checks.
+ // MalformedFragmentsReceived is the number of IP Fragments that were dropped
+ // due to the fragment failing validation checks.
MalformedFragmentsReceived *StatCounter
- // IPTablesPreroutingDropped is the total number of IP packets dropped
- // in the Prerouting chain.
+ // IPTablesPreroutingDropped is the number of IP packets dropped in the
+ // Prerouting chain.
IPTablesPreroutingDropped *StatCounter
- // IPTablesInputDropped is the total number of IP packets dropped in
- // the Input chain.
+ // IPTablesInputDropped is the number of IP packets dropped in the Input
+ // chain.
IPTablesInputDropped *StatCounter
- // IPTablesOutputDropped is the total number of IP packets dropped in
- // the Output chain.
+ // IPTablesOutputDropped is the number of IP packets dropped in the Output
+ // chain.
IPTablesOutputDropped *StatCounter
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 0cb9d034e..38c2f321b 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -135,14 +135,15 @@ func TestForwarding(t *testing.T) {
name string
proto tcpip.TransportProtocolNumber
expectedConnectErr tcpip.Error
- setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
+ setupServer func(t *testing.T, ep tcpip.Endpoint)
+ setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
needRemoteAddr bool
}{
{
name: "UDP",
proto: udp.ProtocolNumber,
expectedConnectErr: nil,
- setupServerSide: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
t.Helper()
if err := ep.Connect(clientAddr); err != nil {
@@ -156,12 +157,16 @@ func TestForwarding(t *testing.T) {
name: "TCP",
proto: tcp.ProtocolNumber,
expectedConnectErr: &tcpip.ErrConnectStarted{},
- setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ setupServer: func(t *testing.T, ep tcpip.Endpoint) {
t.Helper()
if err := ep.Listen(1); err != nil {
t.Fatalf("ep.Listen(1): %s", err)
}
+ },
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
var addr tcpip.FullAddress
for {
newEP, wq, err := ep.Accept(&addr)
@@ -214,6 +219,9 @@ func TestForwarding(t *testing.T) {
t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
}
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
{
err := epsAndAddrs.clientEP.Connect(serverAddr)
if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
@@ -229,7 +237,7 @@ func TestForwarding(t *testing.T) {
serverEP := epsAndAddrs.serverEP
serverCH := epsAndAddrs.serverReadableCH
- if ep, ch := subTest.setupServerSide(t, serverEP, serverCH, clientAddr); ep != nil {
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, clientAddr); ep != nil {
defer ep.Close()
serverEP = ep
serverCH = ch
@@ -256,13 +264,20 @@ func TestForwarding(t *testing.T) {
read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
t.Helper()
- // Wait for the endpoint to be readable.
- <-ch
var buf bytes.Buffer
- opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
- res, err := ep.Read(&buf, opts)
- if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
}
readResult := tcpip.ReadResult{
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 165f73f21..095623789 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -675,9 +675,7 @@ func TestWritePacketsLinkResolution(t *testing.T) {
Length: length,
})
xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
- for _, v := range pkt.Data.Views() {
- xsum = header.Checksum(v, xsum)
- }
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
pkts.PushBack(pkt)
@@ -1169,53 +1167,53 @@ func TestDAD(t *testing.T) {
}
tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- dadNetProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedResolved bool
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ dadNetProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedResult stack.DADResult
}{
{
- name: "IPv4 own address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- expectedResolved: true,
+ name: "IPv4 own address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
+ expectedResult: &stack.DADSucceeded{},
},
{
- name: "IPv6 own address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- expectedResolved: true,
+ name: "IPv6 own address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
+ expectedResult: &stack.DADSucceeded{},
},
{
- name: "IPv4 duplicate address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedResolved: false,
+ name: "IPv4 duplicate address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2},
},
{
- name: "IPv6 duplicate address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedResolved: false,
+ name: "IPv6 duplicate address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2},
},
{
- name: "IPv4 no duplicate address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- expectedResolved: true,
+ name: "IPv4 no duplicate address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
+ expectedResult: &stack.DADSucceeded{},
},
{
- name: "IPv6 no duplicate address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- expectedResolved: true,
+ name: "IPv6 no duplicate address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
+ expectedResult: &stack.DADSucceeded{},
},
}
@@ -1262,7 +1260,7 @@ func TestDAD(t *testing.T) {
}
expectResults := 1
- if test.expectedResolved {
+ if _, ok := test.expectedResult.(*stack.DADSucceeded); ok {
const delta = time.Nanosecond
clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta)
select {
@@ -1287,7 +1285,7 @@ func TestDAD(t *testing.T) {
}
for i := 0; i < expectResults; i++ {
- if diff := cmp.Diff(stack.DADResult{Resolved: test.expectedResolved}, <-ch); diff != "" {
+ if diff := cmp.Diff(test.expectedResult, <-ch); diff != "" {
t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
}
}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index c56155ea2..80afc2825 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -38,7 +38,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
type ndpDispatcher struct{}
-func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) {
+func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index e4439ba79..29266a4fc 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -75,7 +75,11 @@ func TestPingMulticastBroadcast(t *testing.T) {
pkt.SetType(header.ICMPv6EchoRequest)
pkt.SetCode(0)
pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, utils.RemoteIPv6Addr, dst, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: utils.RemoteIPv6Addr,
+ Dst: dst,
+ }))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: header.ICMPv6MinimumSize,
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index f5e1a6e45..06c63e74a 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -26,6 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// TODO(https://gvisor.dev/issues/5623): Unit test this package.
+
// +stateify savable
type icmpPacket struct {
icmpPacketEntry
@@ -414,15 +416,27 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return &tcpip.ErrInvalidEndpointState{}
}
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest
+
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
- pkt.Data = data.ToVectorisedView()
+ pkt.Data().AppendView(data)
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
+
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
+ r.Stats().ICMP.V4.PacketsSent.Dropped.Increment()
+ return err
+ }
+
+ sentStat.Increment()
+ return nil
}
func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error {
@@ -444,15 +458,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return &tcpip.ErrInvalidEndpointState{}
}
-
- dataVV := data.ToVectorisedView()
- icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV))
- pkt.Data = dataVV
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest
+
+ pkt.Data().AppendView(data)
+ dataRange := pkt.Data().AsRange()
+ icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpv6,
+ Src: r.LocalAddress,
+ Dst: r.RemoteAddress,
+ PayloadCsum: dataRange.Checksum(),
+ PayloadLen: dataRange.Size(),
+ }))
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
+
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
+ r.Stats().ICMP.V6.PacketsSent.Dropped.Increment()
+ }
+
+ sentStat.Increment()
+ return nil
}
// checkV4MappedLocked determines the effective network protocol and converts
@@ -763,7 +793,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
// ICMP socket's data includes ICMP header.
packet.data = pkt.TransportHeader().View().ToVectorisedView()
- packet.data.Append(pkt.Data)
+ packet.data.Append(pkt.Data().ExtractVV())
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 73bb66830..367757d3b 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -432,7 +432,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
// Cooked packets can simply be queued.
switch pkt.PktType {
case tcpip.PacketHost:
- packet.data = pkt.Data
+ packet.data = pkt.Data().ExtractVV()
case tcpip.PacketOutgoing:
// Strip Link Header.
var combinedVV buffer.VectorisedView
@@ -442,7 +442,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
if v := pkt.TransportHeader().View(); !v.IsEmpty() {
combinedVV.AppendView(v)
}
- combinedVV.Append(pkt.Data)
+ combinedVV.Append(pkt.Data().ExtractVV())
packet.data = combinedVV
default:
panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
@@ -468,7 +468,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...)
}
combinedVV := linkHeader.ToVectorisedView()
- combinedVV.Append(pkt.Data)
+ combinedVV.Append(pkt.Data().ExtractVV())
packet.data = combinedVV
} else {
packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index fe8e9c751..2709be90c 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -644,7 +644,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
} else {
combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
}
- combinedVV.Append(pkt.Data)
+ combinedVV.Append(pkt.Data().ExtractVV())
packet.data = combinedVV
packet.timestampNS = e.stack.Clock().NowNanoseconds()
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index fcdd032c5..a69d6624d 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -105,7 +105,6 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp/testing/context",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 842c1622b..3b574837c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
@@ -432,15 +433,16 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// * e.mu is held.
func (e *endpoint) reserveTupleLocked() bool {
dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort}
- if !e.stack.ReserveTuple(
- e.effectiveNetProtos,
- ProtocolNumber,
- e.ID.LocalAddress,
- e.ID.LocalPort,
- e.boundPortFlags,
- e.boundBindToDevice,
- dest,
- ) {
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ Flags: e.boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: dest,
+ }
+ if !e.stack.ReserveTuple(portRes) {
return false
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 461b1a9d7..3404af6bb 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -68,7 +68,7 @@ type handshake struct {
ep *endpoint
state handshakeState
active bool
- flags uint8
+ flags header.TCPFlags
ackNum seqnum.Value
// iss is the initial send sequence number, as defined in RFC 793.
@@ -606,7 +606,7 @@ func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer
func (bt *backoffTimer) reset() tcpip.Error {
bt.timeout *= 2
- if bt.timeout > MaxRTO {
+ if bt.timeout > bt.maxTimeout {
return &tcpip.ErrTimeout{}
}
bt.t.Reset(bt.timeout)
@@ -700,7 +700,7 @@ type tcpFields struct {
id stack.TransportEndpointID
ttl uint8
tos uint8
- flags byte
+ flags header.TCPFlags
seq seqnum.Value
ack seqnum.Value
rcvWnd seqnum.Size
@@ -752,7 +752,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
} else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumVV(pkt.Data, xsum)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
}
@@ -786,7 +786,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
})
pkt.Hash = tf.txHash
pkt.Owner = owner
- data.ReadToVV(&pkt.Data, packetSize)
+ pkt.Data().ReadFromVV(&data, packetSize)
buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
pkts.PushBack(pkt)
@@ -877,7 +877,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
}
// sendRaw sends a TCP segment to the endpoint's peer.
-func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error {
+func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error {
var sackBlocks []header.SACKBlock
if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index f47b39ccc..129f36d11 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -760,6 +760,7 @@ func (e *endpoint) LockUser() {
// protocol goroutine altogether.
//
// Precondition: e.LockUser() must have been called before calling e.UnlockUser()
+// +checklocks:e.mu
func (e *endpoint) UnlockUser() {
// Lock segment queue before checking so that we avoid a race where
// segments can be queued between the time we check if queue is empty
@@ -800,6 +801,7 @@ func (e *endpoint) StopWork() {
}
// ResumeWork resumes packet processing. Only to be used in tests.
+// +checklocks:e.mu
func (e *endpoint) ResumeWork() {
e.mu.Unlock()
}
@@ -1095,7 +1097,16 @@ func (e *endpoint) closeNoShutdownLocked() {
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ Flags: e.boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: e.boundDest,
+ }
+ e.stack.ReleasePort(portRes)
e.isPortReserved = false
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
@@ -1170,7 +1181,16 @@ func (e *endpoint) cleanupLocked() {
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ Flags: e.boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: e.boundDest,
+ }
+ e.stack.ReleasePort(portRes)
e.isPortReserved = false
}
e.boundBindToDevice = 0
@@ -2218,7 +2238,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
h.Write(portBuf)
- portOffset := h.Sum32()
+ portOffset := uint16(h.Sum32())
var twReuse tcpip.TCPTimeWaitReuseOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
@@ -2240,7 +2260,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: p,
+ Flags: e.portFlags,
+ BindToDevice: bindToDevice,
+ Dest: addr,
+ }
+ if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse {
return false, nil
}
@@ -2278,7 +2307,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
tcpEP.notifyProtocolGoroutine(notifyAbort)
tcpEP.UnlockUser()
// Now try and Reserve again if it fails then we skip.
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: p,
+ Flags: e.portFlags,
+ BindToDevice: bindToDevice,
+ Dest: addr,
+ }
+ if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
return false, nil
}
}
@@ -2286,7 +2324,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
id := e.ID
id.LocalPort = p
if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr)
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: p,
+ Flags: e.portFlags,
+ BindToDevice: bindToDevice,
+ Dest: addr,
+ }
+ e.stack.ReleasePort(portRes)
if _, ok := err.(*tcpip.ErrPortInUse); ok {
return false, nil
}
@@ -2602,7 +2649,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
}
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: addr.Addr,
+ Port: addr.Port,
+ Flags: e.portFlags,
+ BindToDevice: bindToDevice,
+ Dest: tcpip.FullAddress{},
+ }
+ port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) {
id := e.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2614,9 +2670,9 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
// address/port. Hence this will only return an error if there is a matching
// listening endpoint.
if err := e.stack.CheckRegisterTransportEndpoint(netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
- return false
+ return false, nil
}
- return true
+ return true, nil
})
if err != nil {
return err
@@ -2699,7 +2755,7 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
Cause: transErr,
// Linux passes the payload with the TCP header. We don't know if the TCP
// header even exists, it may not for fragmented packets.
- Payload: pkt.Data.ToView(),
+ Payload: pkt.Data().AsRange().ToOwnedView(),
Dst: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: e.ID.RemoteAddress,
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index e4368026f..a53d76917 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -22,9 +22,11 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// +checklocks:e.mu
func (e *endpoint) drainSegmentLocked() {
// Drain only up to once.
if e.drainDone != nil {
@@ -207,7 +209,16 @@ func (e *endpoint) Resume(s *stack.Stack) {
if err != nil {
panic("unable to parse BindAddr: " + err.String())
}
- if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok {
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: addr.Addr,
+ Port: addr.Port,
+ Flags: e.boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: e.boundDest,
+ }
+ if ok := e.stack.ReserveTuple(portRes); !ok {
panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
}
e.isPortReserved = true
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 04012cd40..2a4667906 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -226,7 +226,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
ack := seqnum.Value(0)
- flags := byte(header.TCPFlagRst)
+ flags := header.TCPFlagRst
// As per RFC 793 page 35 (Reset Generation)
// 1. If the connection does not exist (CLOSED) then a reset is sent
// in response to any incoming segment except another reset. In
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index f27eef6a9..8edd6775b 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -62,7 +62,7 @@ type segment struct {
views [8]buffer.View `state:"nosave"`
sequenceNumber seqnum.Value
ackNumber seqnum.Value
- flags uint8
+ flags header.TCPFlags
window seqnum.Size
// csum is only populated for received segments.
csum uint16
@@ -98,7 +98,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *
netProto: pkt.NetworkProtocolNumber,
nicID: pkt.NICID,
}
- s.data = pkt.Data.Clone(s.views[:])
+ s.data = pkt.Data().ExtractVV().Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
s.dataMemSize = s.data.Size()
@@ -141,12 +141,12 @@ func (s *segment) clone() *segment {
}
// flagIsSet checks if at least one flag in flags is set in s.flags.
-func (s *segment) flagIsSet(flags uint8) bool {
+func (s *segment) flagIsSet(flags header.TCPFlags) bool {
return s.flags&flags != 0
}
// flagsAreSet checks if all flags in flags are set in s.flags.
-func (s *segment) flagsAreSet(flags uint8) bool {
+func (s *segment) flagsAreSet(flags header.TCPFlags) bool {
return s.flags&flags == flags
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 83c8deb0e..18817029d 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -1613,7 +1613,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
// sendSegmentFromView sends a new segment containing the given payload, flags
// and sequence number.
-func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error {
+func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error {
s.lastSendTime = time.Now()
if seq == s.rttMeasureSeqNum {
s.rttMeasureTime = s.lastSendTime
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 0128c1f7e..fd499a47b 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -33,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -1373,7 +1372,7 @@ func TestTOSV4(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790), // Acknum is initial sequence number + 1
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
checker.TOS(tos, 0),
)
@@ -1421,7 +1420,7 @@ func TestTrafficClassV6(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
checker.TOS(tos, 0),
)
@@ -2202,7 +2201,7 @@ func TestSimpleSend(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2242,7 +2241,7 @@ func TestZeroWindowSend(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2264,7 +2263,7 @@ func TestZeroWindowSend(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2311,7 +2310,7 @@ func TestScaledWindowConnect(t *testing.T) {
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(0x5fff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
}
@@ -2342,7 +2341,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(0xffff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
}
@@ -2415,7 +2414,7 @@ func TestScaledWindowAccept(t *testing.T) {
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(0x5fff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
}
@@ -2488,7 +2487,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(0xffff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
}
@@ -2666,7 +2665,7 @@ func TestSegmentMerging(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
}
@@ -2689,7 +2688,7 @@ func TestSegmentMerging(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+11),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2738,7 +2737,7 @@ func TestDelay(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2786,7 +2785,7 @@ func TestUndelay(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2809,7 +2808,7 @@ func TestUndelay(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2872,7 +2871,7 @@ func TestMSSNotDelayed(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2923,7 +2922,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3438,7 +3437,7 @@ func TestMaxRTO(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
const numRetransmits = 2
@@ -3447,7 +3446,7 @@ func TestMaxRTO(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
@@ -3490,7 +3489,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
checker.FragmentFlags(0),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}}
@@ -3502,7 +3501,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
checker.FragmentFlags(0),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
id := header.IPv4(pkt).ID()
@@ -3633,7 +3632,7 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -3710,7 +3709,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -3729,7 +3728,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3796,7 +3795,7 @@ func TestFinWithPendingData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -3822,7 +3821,7 @@ func TestFinWithPendingData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -3886,7 +3885,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -3907,7 +3906,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(791),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3923,7 +3922,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(791),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -4033,7 +4032,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -4783,7 +4782,8 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
t.Fatalf("unknown address type: '%s'", candidateAddressType)
}
- for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
+ start, end := s.PortRange()
+ for i := start; i <= end; i++ {
if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
t.Fatalf("Bind(%d) failed: %s", i, err)
}
@@ -4844,7 +4844,7 @@ func TestPathMTUDiscovery(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(seqNum),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
seqNum += uint32(size)
@@ -5129,7 +5129,7 @@ func TestKeepalive(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -7174,7 +7174,7 @@ func TestTCPCloseWithData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -7274,7 +7274,7 @@ func TestTCPUserTimeout(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 5a9745ad7..cb4f82903 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -170,7 +170,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(wndSize),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
checker.TCPTimestampChecker(true, 0, tsVal+1),
),
)
@@ -231,7 +231,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(wndSize),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
checker.TCPTimestampChecker(false, 0, 0),
),
)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index b1cb9a324..2f1c1011d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -101,7 +101,7 @@ type Headers struct {
AckNum seqnum.Value
// Flags are the TCP flags in the TCP header.
- Flags int
+ Flags header.TCPFlags
// RcvWnd is the window to be advertised in the ReceiveWindow field of
// the TCP header.
@@ -452,7 +452,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp
SeqNum: uint32(h.SeqNum),
AckNum: uint32(h.AckNum),
DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)),
- Flags: uint8(h.Flags),
+ Flags: h.Flags,
WindowSize: uint16(h.RcvWnd),
})
@@ -544,7 +544,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
checker.DstPort(TestPort),
checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -571,7 +571,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
checker.DstPort(TestPort),
checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -650,7 +650,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
SeqNum: uint32(h.SeqNum),
AckNum: uint32(h.AckNum),
DataOffset: header.TCPMinimumSize,
- Flags: uint8(h.Flags),
+ Flags: h.Flags,
WindowSize: uint16(h.RcvWnd),
})
@@ -780,7 +780,7 @@ type RawEndpoint struct {
C *Context
SrcPort uint16
DstPort uint16
- Flags int
+ Flags header.TCPFlags
NextSeqNum seqnum.Value
AckNum seqnum.Value
WndSize seqnum.Size
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
index 5e271b7ca..6c5ddc3c7 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
@@ -465,7 +465,7 @@ func TestIgnoreBadResetOnSynSent(t *testing.T) {
// Receive a RST with a bad ACK, it should not cause the connection to
// be reset.
acks := []uint32{1234, 1236, 1000, 5000}
- flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
+ flags := []header.TCPFlags{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
for _, a := range acks {
for _, f := range flags {
tcp.Encode(&header.TCPFields{
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 807df2bb5..c0f566459 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -245,7 +245,16 @@ func (e *endpoint) Close() {
switch e.EndpointState() {
case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ Flags: e.boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: tcpip.FullAddress{},
+ }
+ e.stack.ReleasePort(portRes)
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
}
@@ -920,7 +929,16 @@ func (e *endpoint) Disconnect() tcpip.Error {
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ Flags: boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: tcpip.FullAddress{},
+ }
+ e.stack.ReleasePort(portRes)
e.boundPortFlags = ports.Flags{}
}
e.setEndpointState(StateInitial)
@@ -1072,7 +1090,16 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ Flags: e.portFlags,
+ BindToDevice: bindToDevice,
+ Dest: tcpip.FullAddress{},
+ }
+ port, err := e.stack.ReservePort(portRes, nil /* testPort */)
if err != nil {
return id, bindToDevice, err
}
@@ -1082,7 +1109,16 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{})
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ Flags: e.boundPortFlags,
+ BindToDevice: bindToDevice,
+ Dest: tcpip.FullAddress{},
+ }
+ e.stack.ReleasePort(portRes)
e.boundPortFlags = ports.Flags{}
}
return id, bindToDevice, err
@@ -1227,7 +1263,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
(hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) {
netHdr := pkt.Network()
xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length())
- for _, v := range pkt.Data.Views() {
+ for _, v := range pkt.Data().Views() {
xsum = header.Checksum(v, xsum)
}
return hdr.CalculateChecksum(xsum) == 0xffff
@@ -1240,7 +1276,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
- if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
+ if int(hdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -1287,10 +1323,10 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
Addr: id.LocalAddress,
Port: header.UDP(hdr).DestinationPort(),
},
+ data: pkt.Data().ExtractVV(),
}
- packet.data = pkt.Data
e.rcvList.PushBack(packet)
- e.rcvBufSize += pkt.Data.Size()
+ e.rcvBufSize += packet.data.Size()
// Save any useful information from the network header to the packet.
switch pkt.NetworkProtocolNumber {
@@ -1327,7 +1363,7 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
if e.SocketOptions().GetRecvError() {
// Linux passes the payload without the UDP header.
var payload []byte
- udp := header.UDP(pkt.Data.ToView())
+ udp := header.UDP(pkt.Data().AsRange().ToOwnedView())
if len(udp) >= header.UDPMinimumSize {
payload = udp.Payload()
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 427fdd0c9..1171aeb79 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -80,7 +80,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) {
// protocol but don't match any existing endpoint.
func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
hdr := header.UDP(pkt.TransportHeader().View())
- if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
+ if int(hdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize {
p.stack.Stats().UDP.MalformedPacketsReceived.Increment()
return stack.UnknownDestinationPacketMalformed
}