summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/device_file.go39
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go43
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go30
-rw-r--r--pkg/sentry/kernel/signal.go3
-rw-r--r--pkg/sentry/vfs/device.go29
-rw-r--r--pkg/sentry/vfs/vfs.go18
-rw-r--r--pkg/tcpip/header/checksum.go131
-rw-r--r--pkg/tcpip/header/checksum_test.go62
-rw-r--r--pkg/tcpip/stack/ndp_test.go47
-rw-r--r--pkg/tcpip/stack/nic.go3
11 files changed, 352 insertions, 54 deletions
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 7601c7c04..691476b4f 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -20,6 +20,7 @@ go_library(
name = "tmpfs",
srcs = [
"dentry_list.go",
+ "device_file.go",
"directory.go",
"filesystem.go",
"named_pipe.go",
diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go
new file mode 100644
index 000000000..84b181b90
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/device_file.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+type deviceFile struct {
+ inode inode
+ kind vfs.DeviceKind
+ major uint32
+ minor uint32
+}
+
+func (fs *filesystem) newDeviceFile(creds *auth.Credentials, mode linux.FileMode, kind vfs.DeviceKind, major, minor uint32) *inode {
+ file := &deviceFile{
+ kind: kind,
+ major: major,
+ minor: minor,
+ }
+ file.inode.init(file, fs, creds, mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index a9f66a42a..d726f03c5 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -228,23 +228,26 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// MknodAt implements vfs.FilesystemImpl.MknodAt.
func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
return fs.doCreateAt(rp, false /* dir */, func(parent *dentry, name string) error {
+ var childInode *inode
switch opts.Mode.FileType() {
case 0, linux.S_IFREG:
- child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode))
- parent.vfsd.InsertChild(&child.vfsd, name)
- parent.inode.impl.(*directory).childList.PushBack(child)
- return nil
+ childInode = fs.newRegularFile(rp.Credentials(), opts.Mode)
case linux.S_IFIFO:
- child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode))
- parent.vfsd.InsertChild(&child.vfsd, name)
- parent.inode.impl.(*directory).childList.PushBack(child)
- return nil
- case linux.S_IFBLK, linux.S_IFCHR, linux.S_IFSOCK:
+ childInode = fs.newNamedPipe(rp.Credentials(), opts.Mode)
+ case linux.S_IFBLK:
+ childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.BlockDevice, opts.DevMajor, opts.DevMinor)
+ case linux.S_IFCHR:
+ childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor)
+ case linux.S_IFSOCK:
// Not yet supported.
return syserror.EPERM
default:
return syserror.EINVAL
}
+ child := fs.newDentry(childInode)
+ parent.vfsd.InsertChild(&child.vfsd, name)
+ parent.inode.impl.(*directory).childList.PushBack(child)
+ return nil
})
}
@@ -264,7 +267,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return nil, err
}
- return d.open(ctx, rp, opts.Flags, false /* afterCreate */)
+ return d.open(ctx, rp, &opts, false /* afterCreate */)
}
mustCreate := opts.Flags&linux.O_EXCL != 0
@@ -279,7 +282,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if mustCreate {
return nil, syserror.EEXIST
}
- return start.open(ctx, rp, opts.Flags, false /* afterCreate */)
+ return start.open(ctx, rp, &opts, false /* afterCreate */)
}
afterTrailingSymlink:
parent, err := walkParentDirLocked(rp, start)
@@ -313,7 +316,7 @@ afterTrailingSymlink:
child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode))
parent.vfsd.InsertChild(&child.vfsd, name)
parent.inode.impl.(*directory).childList.PushBack(child)
- return child.open(ctx, rp, opts.Flags, true)
+ return child.open(ctx, rp, &opts, true)
}
if err != nil {
return nil, err
@@ -327,11 +330,11 @@ afterTrailingSymlink:
if mustCreate {
return nil, syserror.EEXIST
}
- return child.open(ctx, rp, opts.Flags, false)
+ return child.open(ctx, rp, &opts, false)
}
-func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
- ats := vfs.AccessTypesForOpenFlags(flags)
+func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts.Flags)
if !afterCreate {
if err := d.inode.checkPermissions(rp.Credentials(), ats, d.inode.isDir()); err != nil {
return nil, err
@@ -340,10 +343,10 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, flags uint32,
switch impl := d.inode.impl.(type) {
case *regularFile:
var fd regularFileFD
- if err := fd.vfsfd.Init(&fd, flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
- if flags&linux.O_TRUNC != 0 {
+ if opts.Flags&linux.O_TRUNC != 0 {
impl.mu.Lock()
impl.data.Truncate(0, impl.memFile)
atomic.StoreUint64(&impl.size, 0)
@@ -356,7 +359,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, flags uint32,
return nil, syserror.EISDIR
}
var fd directoryFD
- if err := fd.vfsfd.Init(&fd, flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
@@ -364,7 +367,9 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, flags uint32,
// Can't open symlinks without O_PATH (which is unimplemented).
return nil, syserror.ELOOP
case *namedPipe:
- return newNamedPipeFD(ctx, impl, rp, &d.vfsd, flags)
+ return newNamedPipeFD(ctx, impl, rp, &d.vfsd, opts.Flags)
+ case *deviceFile:
+ return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts)
default:
panic(fmt.Sprintf("unknown inode type: %T", d.inode.impl))
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 1d4889c89..515f033f2 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -149,6 +149,10 @@ type inode struct {
ctime int64 // nanoseconds
mtime int64 // nanoseconds
+ // Only meaningful for device special files.
+ rdevMajor uint32
+ rdevMinor uint32
+
impl interface{} // immutable
}
@@ -269,6 +273,15 @@ func (i *inode) statTo(stat *linux.Statx) {
stat.Blocks = allocatedBlocksForSize(stat.Size)
case *namedPipe:
stat.Mode |= linux.S_IFIFO
+ case *deviceFile:
+ switch impl.kind {
+ case vfs.BlockDevice:
+ stat.Mode |= linux.S_IFBLK
+ case vfs.CharDevice:
+ stat.Mode |= linux.S_IFCHR
+ }
+ stat.RdevMajor = impl.major
+ stat.RdevMinor = impl.minor
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
}
@@ -309,12 +322,8 @@ func (i *inode) setStat(stat linux.Statx) error {
}
case *directory:
return syserror.EISDIR
- case *symlink:
- return syserror.EINVAL
- case *namedPipe:
- // Nothing.
default:
- panic(fmt.Sprintf("unknown inode type: %T", i.impl))
+ return syserror.EINVAL
}
}
if mask&linux.STATX_ATIME != 0 {
@@ -353,13 +362,22 @@ func allocatedBlocksForSize(size uint64) uint64 {
}
func (i *inode) direntType() uint8 {
- switch i.impl.(type) {
+ switch impl := i.impl.(type) {
case *regularFile:
return linux.DT_REG
case *directory:
return linux.DT_DIR
case *symlink:
return linux.DT_LNK
+ case *deviceFile:
+ switch impl.kind {
+ case vfs.BlockDevice:
+ return linux.DT_BLK
+ case vfs.CharDevice:
+ return linux.DT_CHR
+ default:
+ panic(fmt.Sprintf("unknown vfs.DeviceKind: %v", impl.kind))
+ }
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
}
diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go
index 02eede93d..e8cce37d0 100644
--- a/pkg/sentry/kernel/signal.go
+++ b/pkg/sentry/kernel/signal.go
@@ -38,6 +38,9 @@ const SignalPanic = linux.SIGUSR2
// Preconditions: Kernel must have an init process.
func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) {
switch linux.Signal(info.Signo) {
+ case linux.SIGURG:
+ // Sent by the Go 1.14+ runtime for asynchronous goroutine preemption.
+
case platform.SignalInterrupt:
// Assume that a call to platform.Context.Interrupt() misfired.
diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go
index cb672e36f..9f9d6e783 100644
--- a/pkg/sentry/vfs/device.go
+++ b/pkg/sentry/vfs/device.go
@@ -98,3 +98,32 @@ func (vfs *VirtualFilesystem) OpenDeviceSpecialFile(ctx context.Context, mnt *Mo
}
return rd.dev.Open(ctx, mnt, d, *opts)
}
+
+// GetAnonBlockDevMinor allocates and returns an unused minor device number for
+// an "anonymous" block device with major number 0.
+func (vfs *VirtualFilesystem) GetAnonBlockDevMinor() (uint32, error) {
+ vfs.anonBlockDevMinorMu.Lock()
+ defer vfs.anonBlockDevMinorMu.Unlock()
+ minor := vfs.anonBlockDevMinorNext
+ const maxDevMinor = (1 << 20) - 1
+ for minor < maxDevMinor {
+ if _, ok := vfs.anonBlockDevMinor[minor]; !ok {
+ vfs.anonBlockDevMinor[minor] = struct{}{}
+ vfs.anonBlockDevMinorNext = minor + 1
+ return minor, nil
+ }
+ minor++
+ }
+ return 0, syserror.EMFILE
+}
+
+// PutAnonBlockDevMinor deallocates a minor device number returned by a
+// previous call to GetAnonBlockDevMinor.
+func (vfs *VirtualFilesystem) PutAnonBlockDevMinor(minor uint32) {
+ vfs.anonBlockDevMinorMu.Lock()
+ defer vfs.anonBlockDevMinorMu.Unlock()
+ delete(vfs.anonBlockDevMinor, minor)
+ if minor < vfs.anonBlockDevMinorNext {
+ vfs.anonBlockDevMinorNext = minor
+ }
+}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 1f21b0b31..1f6f56293 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -80,6 +80,14 @@ type VirtualFilesystem struct {
devicesMu sync.RWMutex
devices map[devTuple]*registeredDevice
+ // anonBlockDevMinor contains all allocated anonymous block device minor
+ // numbers. anonBlockDevMinorNext is a lower bound for the smallest
+ // unallocated anonymous block device number. anonBlockDevMinorNext and
+ // anonBlockDevMinor are protected by anonBlockDevMinorMu.
+ anonBlockDevMinorMu sync.Mutex
+ anonBlockDevMinorNext uint32
+ anonBlockDevMinor map[uint32]struct{}
+
// fsTypes contains all registered FilesystemTypes. fsTypes is protected by
// fsTypesMu.
fsTypesMu sync.RWMutex
@@ -94,10 +102,12 @@ type VirtualFilesystem struct {
// New returns a new VirtualFilesystem with no mounts or FilesystemTypes.
func New() *VirtualFilesystem {
vfs := &VirtualFilesystem{
- mountpoints: make(map[*Dentry]map[*Mount]struct{}),
- devices: make(map[devTuple]*registeredDevice),
- fsTypes: make(map[string]*registeredFilesystemType),
- filesystems: make(map[*Filesystem]struct{}),
+ mountpoints: make(map[*Dentry]map[*Mount]struct{}),
+ devices: make(map[devTuple]*registeredDevice),
+ anonBlockDevMinorNext: 1,
+ anonBlockDevMinor: make(map[uint32]struct{}),
+ fsTypes: make(map[string]*registeredFilesystemType),
+ filesystems: make(map[*Filesystem]struct{}),
}
vfs.mounts.Init()
return vfs
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 9749c7f4d..204285576 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -45,12 +45,139 @@ func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
return ChecksumCombine(uint16(v), uint16(v>>16)), odd
}
+func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
+ v := initial
+
+ if odd {
+ v += uint32(buf[0])
+ buf = buf[1:]
+ }
+
+ l := len(buf)
+ odd = l&1 != 0
+ if odd {
+ l--
+ v += uint32(buf[l]) << 8
+ }
+ for (l - 64) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[64:]
+ l = l - 64
+ }
+ if (l - 32) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[32:]
+ l = l - 32
+ }
+ if (l - 16) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[16:]
+ l = l - 16
+ }
+ if (l - 8) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ buf = buf[8:]
+ l = l - 8
+ }
+ if (l - 4) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ buf = buf[4:]
+ l = l - 4
+ }
+
+ // At this point since l was even before we started unrolling
+ // there can be only two bytes left to add.
+ if l != 0 {
+ v += (uint32(buf[0]) << 8) + uint32(buf[1])
+ }
+
+ return ChecksumCombine(uint16(v), uint16(v>>16)), odd
+}
+
+// ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in
+// the given byte array. This function uses a non-optimized implementation. Its
+// only retained for reference and to use as a benchmark/test. Most code should
+// use the header.Checksum function.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumOld(buf []byte, initial uint16) uint16 {
+ s, _ := calculateChecksum(buf, false, uint32(initial))
+ return s
+}
+
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
-// given byte array.
+// given byte array. This function uses an optimized unrolled version of the
+// checksum algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
- s, _ := calculateChecksum(buf, false, uint32(initial))
+ s, _ := unrolledCalculateChecksum(buf, false, uint32(initial))
return s
}
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 86b466c1c..309403482 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -17,6 +17,8 @@
package header_test
import (
+ "fmt"
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -107,3 +109,63 @@ func TestChecksumVVWithOffset(t *testing.T) {
})
}
}
+
+func TestChecksum(t *testing.T) {
+ var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024}
+ type testCase struct {
+ buf []byte
+ initial uint16
+ csumOrig uint16
+ csumNew uint16
+ }
+ testCases := make([]testCase, 100000)
+ // Ensure same buffer generation for test consistency.
+ rnd := rand.New(rand.NewSource(42))
+ for i := range testCases {
+ testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)])
+ testCases[i].initial = uint16(rnd.Intn(65536))
+ rnd.Read(testCases[i].buf)
+ }
+
+ for i := range testCases {
+ testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial)
+ testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial)
+ if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want {
+ t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want)
+ }
+ }
+}
+
+func BenchmarkChecksum(b *testing.B) {
+ var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536}
+
+ checkSumImpls := []struct {
+ fn func([]byte, uint16) uint16
+ name string
+ }{
+ {header.ChecksumOld, fmt.Sprintf("checksum_old")},
+ {header.Checksum, fmt.Sprintf("checksum")},
+ }
+
+ for _, csumImpl := range checkSumImpls {
+ // Ensure same buffer generation for test consistency.
+ rnd := rand.New(rand.NewSource(42))
+ for _, bufSz := range bufSizes {
+ b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) {
+ tc := struct {
+ buf []byte
+ initial uint16
+ csum uint16
+ }{
+ buf: make([]byte, bufSz),
+ initial: uint16(rnd.Intn(65536)),
+ }
+ rnd.Read(tc.buf)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ tc.csum = csumImpl.fn(tc.buf, tc.initial)
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 376681b30..f9460bd51 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -35,13 +35,14 @@ import (
)
const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
- linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
- defaultTimeout = 100 * time.Millisecond
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ defaultTimeout = 100 * time.Millisecond
+ defaultAsyncEventTimeout = time.Second
)
var (
@@ -1086,7 +1087,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
@@ -1103,7 +1104,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1342,7 +1343,7 @@ func TestPrefixDiscovery(t *testing.T) {
if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout):
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for prefix discovery event")
}
@@ -1681,7 +1682,7 @@ func TestAutoGenAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
@@ -1987,7 +1988,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -2027,7 +2028,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -2041,7 +2042,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout)
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -2073,7 +2074,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
@@ -2088,7 +2089,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(newMinVLDuration + defaultTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -2213,7 +2214,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(minVLSeconds*time.Second + defaultTimeout):
+ case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -2701,7 +2702,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout):
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3325,12 +3326,12 @@ func TestRouterSolicitation(t *testing.T) {
// times.
remaining := test.maxRtrSolicit
if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultTimeout)
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
remaining--
}
for ; remaining > 0; remaining-- {
waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout)
- waitForPkt(2 * defaultTimeout)
+ waitForPkt(defaultAsyncEventTimeout)
}
// Make sure no more RS.
@@ -3411,9 +3412,9 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Disable forwarding which should start router solicitations.
s.SetForwarding(false)
- waitForPkt(delay + defaultTimeout)
- waitForPkt(interval + defaultTimeout)
- waitForPkt(interval + defaultTimeout)
+ waitForPkt(delay + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
select {
case <-e.C:
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 79556a36f..7dad9a8cb 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1208,6 +1208,9 @@ func (n *NIC) Stack() *Stack {
// false. It will only return true if the address is associated with the NIC
// AND it is tentative.
func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return false