diff options
157 files changed, 4416 insertions, 811 deletions
diff --git a/kokoro/runtime_tests.cfg b/kokoro/runtime_tests.cfg new file mode 100644 index 000000000..7d56d5aca --- /dev/null +++ b/kokoro/runtime_tests.cfg @@ -0,0 +1 @@ +build_file: "repo/scripts/runtime_tests.sh" diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 7c17109a6..51774c6b6 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -23,6 +23,8 @@ go_library( "exec.go", "fcntl.go", "file.go", + "file_amd64.go", + "file_arm64.go", "fs.go", "futex.go", "inotify.go", diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 257f67222..c9ee098f4 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -186,25 +186,6 @@ const ( RWF_VALID = RWF_HIPRI | RWF_DSYNC | RWF_SYNC ) -// Stat represents struct stat. -type Stat struct { - Dev uint64 - Ino uint64 - Nlink uint64 - Mode uint32 - UID uint32 - GID uint32 - _ int32 - Rdev uint64 - Size int64 - Blksize int64 - Blocks int64 - ATime Timespec - MTime Timespec - CTime Timespec - _ [3]int64 -} - // SizeOfStat is the size of a Stat struct. var SizeOfStat = binary.Size(Stat{}) diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go new file mode 100644 index 000000000..74c554be6 --- /dev/null +++ b/pkg/abi/linux/file_amd64.go @@ -0,0 +1,34 @@ +// Copyright 2018 The gVisor Authors. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Stat represents struct stat. +type Stat struct { + Dev uint64 + Ino uint64 + Nlink uint64 + Mode uint32 + UID uint32 + GID uint32 + _ int32 + Rdev uint64 + Size int64 + Blksize int64 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [3]int64 +} diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go new file mode 100644 index 000000000..f16c07589 --- /dev/null +++ b/pkg/abi/linux/file_arm64.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Stat represents struct stat. +type Stat struct { + Dev uint64 + Ino uint64 + Mode uint32 + Nlink uint32 + UID uint32 + GID uint32 + Rdev uint64 + _ uint64 + Size int64 + Blksize int32 + _ int32 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [2]int32 +} diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 152f6b144..3898d2314 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -325,3 +325,9 @@ const ( RTA_SPORT = 28 RTA_DPORT = 29 ) + +// Route flags, from include/uapi/linux/route.h. +const ( + RTF_GATEWAY = 0x2 + RTF_UP = 0x1 +) diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 25530adca..6039f5a42 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -814,7 +814,7 @@ func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) { attr.Mode = FileMode(s.Mode) } if req.NLink { - attr.NLink = s.Nlink + attr.NLink = uint64(s.Nlink) } if req.UID { attr.UID = UID(s.Uid) diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index 30ec8e6e2..38cea9be3 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -14,7 +14,7 @@ // +build amd64 // +build go1.8 -// +build !go1.14 +// +build !go1.15 #include "textflag.h" diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index e340d9f98..4f4b70fef 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -14,7 +14,7 @@ // +build arm64 // +build go1.8 -// +build !go1.14 +// +build !go1.15 #include "textflag.h" diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 497475db5..75cbb0622 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -52,6 +52,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/tcpip/header", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index f70239449..402919924 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -18,6 +18,7 @@ import ( "bytes" "fmt" "io" + "reflect" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -33,6 +34,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // newNet creates a new proc net entry. @@ -40,7 +43,8 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo var contents map[string]*fs.Inode if s := p.k.NetworkStack(); s != nil { contents = map[string]*fs.Inode{ - "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), // The following files are simple stubs until they are // implemented in netstack, if the file contains a @@ -57,7 +61,7 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo // (ClockGetres returns 1ns resolution). "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")), - "route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")), + "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc), "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc), "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc), "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), @@ -195,6 +199,193 @@ func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } +// netSnmp implements seqfile.SeqSource for /proc/net/snmp. +// +// +stateify savable +type netSnmp struct { + s inet.Stack +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (n *netSnmp) NeedsUpdate(generation int64) bool { + return true +} + +type snmpLine struct { + prefix string + header string +} + +var snmp = []snmpLine{ + { + prefix: "Ip", + header: "Forwarding DefaultTTL InReceives InHdrErrors InAddrErrors ForwDatagrams InUnknownProtos InDiscards InDelivers OutRequests OutDiscards OutNoRoutes ReasmTimeout ReasmReqds ReasmOKs ReasmFails FragOKs FragFails FragCreates", + }, + { + prefix: "Icmp", + header: "InMsgs InErrors InCsumErrors InDestUnreachs InTimeExcds InParmProbs InSrcQuenchs InRedirects InEchos InEchoReps InTimestamps InTimestampReps InAddrMasks InAddrMaskReps OutMsgs OutErrors OutDestUnreachs OutTimeExcds OutParmProbs OutSrcQuenchs OutRedirects OutEchos OutEchoReps OutTimestamps OutTimestampReps OutAddrMasks OutAddrMaskReps", + }, + { + prefix: "IcmpMsg", + }, + { + prefix: "Tcp", + header: "RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens AttemptFails EstabResets CurrEstab InSegs OutSegs RetransSegs InErrs OutRsts InCsumErrors", + }, + { + prefix: "Udp", + header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", + }, + { + prefix: "UdpLite", + header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", + }, +} + +func toSlice(a interface{}) []uint64 { + v := reflect.Indirect(reflect.ValueOf(a)) + return v.Slice(0, v.Len()).Interface().([]uint64) +} + +func sprintSlice(s []uint64) string { + if len(s) == 0 { + return "" + } + r := fmt.Sprint(s) + return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice. +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. See Linux's +// net/core/net-procfs.c:dev_seq_show. +func (n *netSnmp) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + + contents := make([]string, 0, len(snmp)*2) + types := []interface{}{ + &inet.StatSNMPIP{}, + &inet.StatSNMPICMP{}, + nil, // TODO(gvisor.dev/issue/628): Support IcmpMsg stats. + &inet.StatSNMPTCP{}, + &inet.StatSNMPUDP{}, + &inet.StatSNMPUDPLite{}, + } + for i, stat := range types { + line := snmp[i] + if stat == nil { + contents = append( + contents, + fmt.Sprintf("%s:\n", line.prefix), + fmt.Sprintf("%s:\n", line.prefix), + ) + continue + } + if err := n.s.Statistics(stat, line.prefix); err != nil { + if err == syserror.EOPNOTSUPP { + log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) + } else { + log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) + } + } + var values string + if line.prefix == "Tcp" { + tcp := stat.(*inet.StatSNMPTCP) + // "Tcp" needs special processing because MaxConn is signed. RFC 2012. + values = fmt.Sprintf("%s %d %s", sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:])) + } else { + values = sprintSlice(toSlice(stat)) + } + contents = append( + contents, + fmt.Sprintf("%s: %s\n", line.prefix, line.header), + fmt.Sprintf("%s: %s\n", line.prefix, values), + ) + } + + data := make([]seqfile.SeqData, 0, len(snmp)*2) + for _, l := range contents { + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netSnmp)(nil)}) + } + + return data, 0 +} + +// netRoute implements seqfile.SeqSource for /proc/net/route. +// +// +stateify savable +type netRoute struct { + s inet.Stack +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (n *netRoute) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. +func (n *netRoute) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + + interfaces := n.s.Interfaces() + contents := []string{"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT"} + for _, rt := range n.s.RouteTable() { + // /proc/net/route only includes ipv4 routes. + if rt.Family != linux.AF_INET { + continue + } + + // /proc/net/route does not include broadcast or multicast routes. + if rt.Type == linux.RTN_BROADCAST || rt.Type == linux.RTN_MULTICAST { + continue + } + + iface, ok := interfaces[rt.OutputInterface] + if !ok || iface.Name == "lo" { + continue + } + + var ( + gw uint32 + prefix uint32 + flags = linux.RTF_UP + ) + if len(rt.GatewayAddr) == header.IPv4AddressSize { + flags |= linux.RTF_GATEWAY + gw = usermem.ByteOrder.Uint32(rt.GatewayAddr) + } + if len(rt.DstAddr) == header.IPv4AddressSize { + prefix = usermem.ByteOrder.Uint32(rt.DstAddr) + } + l := fmt.Sprintf( + "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d", + iface.Name, + prefix, + gw, + flags, + 0, // RefCnt. + 0, // Use. + 0, // Metric. + (uint32(1)<<rt.DstLen)-1, + 0, // MTU. + 0, // Window. + 0, // RTT. + ) + contents = append(contents, l) + } + + var data []seqfile.SeqData + for _, l := range contents { + l = fmt.Sprintf("%-127s\n", l) + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netRoute)(nil)}) + } + + return data, 0 +} + // netUnix implements seqfile.SeqSource for /proc/net/unix. // // +stateify savable diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index d5284f0d9..8d60ad4ad 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -13,5 +13,8 @@ go_library( "test_stack.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/inet", - deps = ["//pkg/sentry/context"], + deps = [ + "//pkg/sentry/context", + "//pkg/tcpip/stack", + ], ) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 80f227dbe..a7dfb78a7 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -15,6 +15,8 @@ // Package inet defines semantics for IP stacks. package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // Stack represents a TCP/IP stack. type Stack interface { // Interfaces returns all network interfaces as a mapping from interface @@ -58,6 +60,16 @@ type Stack interface { // Resume restarts the network stack after restore. Resume() + + // RegisteredEndpoints returns all endpoints which are currently registered. + RegisteredEndpoints() []stack.TransportEndpoint + + // CleanupEndpoints returns endpoints currently in the cleanup state. + CleanupEndpoints() []stack.TransportEndpoint + + // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful + // for restoring a stack after a save. + RestoreCleanupEndpoints([]stack.TransportEndpoint) } // Interface contains information about a network interface. @@ -153,3 +165,23 @@ type Route struct { // GatewayAddr is the route gateway address (RTA_GATEWAY). GatewayAddr []byte } + +// Below SNMP metrics are from Linux/usr/include/linux/snmp.h. + +// StatSNMPIP describes Ip line of /proc/net/snmp. +type StatSNMPIP [19]uint64 + +// StatSNMPICMP describes Icmp line of /proc/net/snmp. +type StatSNMPICMP [27]uint64 + +// StatSNMPICMPMSG describes IcmpMsg line of /proc/net/snmp. +type StatSNMPICMPMSG [512]uint64 + +// StatSNMPTCP describes Tcp line of /proc/net/snmp. +type StatSNMPTCP [15]uint64 + +// StatSNMPUDP describes Udp line of /proc/net/snmp. +type StatSNMPUDP [8]uint64 + +// StatSNMPUDPLite describes UdpLite line of /proc/net/snmp. +type StatSNMPUDPLite [8]uint64 diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index b9eed7c3a..dcfcbd97e 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -14,6 +14,8 @@ package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // TestStack is a dummy implementation of Stack for tests. type TestStack struct { InterfacesMap map[int32]Interface @@ -94,5 +96,17 @@ func (s *TestStack) RouteTable() []Route { } // Resume implements Stack.Resume. -func (s *TestStack) Resume() { +func (s *TestStack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint { + return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { + return nil +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index cc3f43a45..11f613a11 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -81,6 +81,9 @@ type FDTable struct { // mu protects below. mu sync.Mutex `state:"nosave"` + // next is start position to find fd. + next int32 + // used contains the number of non-nil entries. It must be accessed // atomically. It may be read atomically without holding mu (but not // written). @@ -226,6 +229,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags f.mu.Lock() defer f.mu.Unlock() + // From f.next to find available fd. + if fd < f.next { + fd = f.next + } + // Install all entries. for i := fd; i < end && len(fds) < len(files); i++ { if d, _, _ := f.get(i); d == nil { @@ -242,6 +250,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags return nil, syscall.EMFILE } + if fd == f.next { + // Update next search start position. + f.next = fds[len(fds)-1] + 1 + } + return fds, nil } @@ -361,6 +374,11 @@ func (f *FDTable) Remove(fd int32) *fs.File { f.mu.Lock() defer f.mu.Unlock() + // Update current available position. + if fd < f.next { + f.next = fd + } + orig, _, _ := f.get(fd) if orig != nil { orig.IncRef() // Reference for caller. @@ -377,6 +395,10 @@ func (f *FDTable) RemoveIf(cond func(*fs.File, FDFlags) bool) { f.forEach(func(fd int32, file *fs.File, flags FDFlags) { if cond(file, flags) { f.set(fd, nil, FDFlags{}) // Clear from table. + // Update current available position. + if fd < f.next { + f.next = fd + } } }) } diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 2413788e7..2bcb6216a 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -70,6 +70,42 @@ func TestFDTableMany(t *testing.T) { if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil { t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) } + + i := int32(2) + fdTable.Remove(i) + if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i { + t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err) + } + }) +} + +func TestFDTableOverLimit(t *testing.T) { + runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { + if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil { + t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error") + } + + if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil { + t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error") + } + + if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil { + t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err) + } else { + for _, fd := range fds { + fdTable.Remove(fd) + } + } + + if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 { + t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) + } + + if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { + t.Fatalf("Adding an FD to a resized map: got %v, want nil", err) + } else if len(fds) != 1 || fds[0] != 0 { + t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds) + } }) } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 3cda03891..e64d648e2 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -804,8 +804,21 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, // Create a fresh task context. remainingTraversals = uint(args.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Mounts: mounts, + Root: root, + WorkingDirectory: wd, + RemainingTraversals: &remainingTraversals, + ResolveFinal: true, + Filename: args.Filename, + File: args.File, + CloseOnExec: false, + Argv: args.Argv, + Envv: args.Envv, + Features: k.featureSet, + } - tc, se := k.LoadTaskImage(ctx, mounts, root, wd, &remainingTraversals, args.Filename, args.File, args.Argv, args.Envv, k.featureSet) + tc, se := k.LoadTaskImage(ctx, loadArgs) if se != nil { return nil, 0, errors.New(se.String()) } diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index c82ef5486..9be3dae3c 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -709,9 +709,9 @@ func (t *Task) FDTable() *FDTable { return t.fdTable } -// GetFile is a convenience wrapper t.FDTable().GetFile. +// GetFile is a convenience wrapper for t.FDTable().Get. // -// Precondition: same as FDTable. +// Precondition: same as FDTable.Get. func (t *Task) GetFile(fd int32) *fs.File { f, _ := t.fdTable.Get(fd) return f diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 8639d379f..bb5560acf 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -18,10 +18,8 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/mm" @@ -132,30 +130,21 @@ func (t *Task) Stack() *arch.Stack { return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())} } -// LoadTaskImage loads filename into a new TaskContext. +// LoadTaskImage loads a specified file into a new TaskContext. // -// It takes several arguments: -// * mounts: MountNamespace to lookup filename in -// * root: Root to lookup filename under -// * wd: Working directory to lookup filename under -// * maxTraversals: maximum number of symlinks to follow -// * filename: path to binary to load -// * file: an open fs.File object of the binary to load. If set, -// file will be loaded and not filename. -// * argv: Binary argv -// * envv: Binary envv -// * fs: Binary FeatureSet -func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, file *fs.File, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) { - // If File is not nil, we should load that instead of resolving filename. - if file != nil { - filename = file.MappedName(ctx) +// args.MemoryManager does not need to be set by the caller. +func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) { + // If File is not nil, we should load that instead of resolving Filename. + if args.File != nil { + args.Filename = args.File.MappedName(ctx) } // Prepare a new user address space to load into. m := mm.NewMemoryManager(k, k) defer m.DecUsers(ctx) + args.MemoryManager = m - os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv, envv, k.extraAuxv, k.vdso) + os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso) if err != nil { return nil, err } diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 2d9251e92..c2c3ec06e 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -624,15 +624,15 @@ func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, in return loadParsedELF(ctx, m, f, info, 0) } -// loadELF loads f into the Task address space. +// loadELF loads args.File into the Task address space. // // If loadELF returns ErrSwitchFile it should be called again with the returned // path and argv. // // Preconditions: -// * f is an ELF file -func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) { - bin, ac, err := loadInitialELF(ctx, m, fs, f) +// * args.File is an ELF file +func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error) { + bin, ac, err := loadInitialELF(ctx, args.MemoryManager, args.Features, args.File) if err != nil { ctx.Infof("Error loading binary: %v", err) return loadedELF{}, nil, err @@ -640,7 +640,14 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace var interp loadedELF if bin.interpreter != "" { - d, i, err := openPath(ctx, mounts, root, wd, maxTraversals, bin.interpreter) + // Even if we do not allow the final link of the script to be + // resolved, the interpreter should still be resolved if it is + // a symlink. + args.ResolveFinal = true + // Refresh the traversal limit. + *args.RemainingTraversals = linux.MaxSymlinkTraversals + args.Filename = bin.interpreter + d, i, err := openPath(ctx, args) if err != nil { ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err) return loadedELF{}, nil, err @@ -649,7 +656,7 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace // We don't need the Dirent. d.DecRef() - interp, err = loadInterpreterELF(ctx, m, i, bin) + interp, err = loadInterpreterELF(ctx, args.MemoryManager, i, bin) if err != nil { ctx.Infof("Error loading interpreter: %v", err) return loadedELF{}, nil, err diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 089d1635b..b03eeb005 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package loader loads a binary into a MemoryManager. +// Package loader loads an executable file into a MemoryManager. package loader import ( @@ -20,6 +20,7 @@ import ( "fmt" "io" "path" + "strings" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" @@ -35,6 +36,54 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// LoadArgs holds specifications for an executable file to be loaded. +type LoadArgs struct { + // MemoryManager is the memory manager to load the executable into. + MemoryManager *mm.MemoryManager + + // Mounts is the mount namespace in which to look up Filename. + Mounts *fs.MountNamespace + + // Root is the root directory under which to look up Filename. + Root *fs.Dirent + + // WorkingDirectory is the working directory under which to look up + // Filename. + WorkingDirectory *fs.Dirent + + // RemainingTraversals is the maximum number of symlinks to follow to + // resolve Filename. This counter is passed by reference to keep it + // updated throughout the call stack. + RemainingTraversals *uint + + // ResolveFinal indicates whether the final link of Filename should be + // resolved, if it is a symlink. + ResolveFinal bool + + // Filename is the path for the executable. + Filename string + + // File is an open fs.File object of the executable. If File is not + // nil, then File will be loaded and Filename will be ignored. + File *fs.File + + // CloseOnExec indicates that the executable (or one of its parent + // directories) was opened with O_CLOEXEC. If the executable is an + // interpreter script, then cause an ENOENT error to occur, since the + // script would otherwise be inaccessible to the interpreter. + CloseOnExec bool + + // Argv is the vector of arguments to pass to the executable. + Argv []string + + // Envv is the vector of environment variables to pass to the + // executable. + Envv []string + + // Features specifies the CPU feature set for the executable. + Features *cpuid.FeatureSet +} + // readFull behaves like io.ReadFull for an *fs.File. func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { var total int64 @@ -51,80 +100,82 @@ func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset in return total, nil } -// openPath opens name for loading. +// openPath opens args.Filename and checks that it is valid for loading. // -// openPath returns the fs.Dirent and an *fs.File for name, which is not +// openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not // installed in the Task FDTable. The caller takes ownership of both. // -// name must be a readable, executable, regular file. -func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, name string) (*fs.Dirent, *fs.File, error) { - if name == "" { +// args.Filename must be a readable, executable, regular file. +func openPath(ctx context.Context, args LoadArgs) (*fs.Dirent, *fs.File, error) { + if args.Filename == "" { ctx.Infof("cannot open empty name") return nil, nil, syserror.ENOENT } - d, err := mm.FindInode(ctx, root, wd, name, maxTraversals) + var d *fs.Dirent + var err error + if args.ResolveFinal { + d, err = args.Mounts.FindInode(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals) + } else { + d, err = args.Mounts.FindLink(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals) + } if err != nil { return nil, nil, err } - - // Open file will take a reference to Dirent, so destroy this one. + // Defer a DecRef for the sake of failure cases. defer d.DecRef() - return openFile(ctx, nil, d, name) -} + if !args.ResolveFinal && fs.IsSymlink(d.Inode.StableAttr) { + return nil, nil, syserror.ELOOP + } -// openFile performs checks on a file to be executed. If provided a *fs.File, -// openFile takes that file's Dirent and performs checks on it. If provided a -// *fs.Dirent and not a *fs.File, it creates a *fs.File object from the Dirent's -// Inode and performs checks on that. -// -// openFile returns an *fs.File and *fs.Dirent, and the caller takes ownership -// of both. -// -// "dirent" and "file" must not both be nil and point to a readable, executable, regular file. -func openFile(ctx context.Context, file *fs.File, dirent *fs.Dirent, name string) (*fs.Dirent, *fs.File, error) { - // file and dirent must not be nil. - if dirent == nil && file == nil { - ctx.Infof("dirent and file cannot both be nil.") - return nil, nil, syserror.ENOENT + if err := checkPermission(ctx, d); err != nil { + return nil, nil, err } - if file != nil { - dirent = file.Dirent + // If they claim it's a directory, then make sure. + // + // N.B. we reject directories below, but we must first reject + // non-directories passed as directories. + if strings.HasSuffix(args.Filename, "/") && !fs.IsDir(d.Inode.StableAttr) { + return nil, nil, syserror.ENOTDIR } - // Perform permissions checks on the file. - if err := checkFile(ctx, dirent, name); err != nil { + if err := checkIsRegularFile(ctx, d, args.Filename); err != nil { return nil, nil, err } - if file == nil { - var ferr error - if file, ferr = dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}); ferr != nil { - return nil, nil, ferr - } - } else { - // GetFile takes a reference to the created file, so make one in the case - // that the file reference already existed. - file.IncRef() + f, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) + if err != nil { + return nil, nil, err + } + // Defer a DecRef for the sake of failure cases. + defer f.DecRef() + + if err := checkPread(ctx, f, args.Filename); err != nil { + return nil, nil, err + } + + d.IncRef() + f.IncRef() + return d, f, err +} + +// checkFile performs checks on a file to be executed. +func checkFile(ctx context.Context, f *fs.File, filename string) error { + if err := checkPermission(ctx, f.Dirent); err != nil { + return err } - // We must be able to read at arbitrary offsets. - if !file.Flags().Pread { - file.DecRef() - ctx.Infof("%s cannot be read at an offset: %+v", file.MappedName(ctx), file.Flags()) - return nil, nil, syserror.EACCES + if err := checkIsRegularFile(ctx, f.Dirent, filename); err != nil { + return err } - // Grab reference for caller. - dirent.IncRef() - return dirent, file, nil + return checkPread(ctx, f, filename) } -// checkFile performs file permissions checks for binaries called in openPath -// and openFile -func checkFile(ctx context.Context, d *fs.Dirent, name string) error { +// checkPermission checks whether the file is readable and executable. +func checkPermission(ctx context.Context, d *fs.Dirent) error { perms := fs.PermMask{ // TODO(gvisor.dev/issue/160): Linux requires only execute // permission, not read. However, our backing filesystems may @@ -135,26 +186,26 @@ func checkFile(ctx context.Context, d *fs.Dirent, name string) error { Read: true, Execute: true, } - if err := d.Inode.CheckPermission(ctx, perms); err != nil { - return err - } + return d.Inode.CheckPermission(ctx, perms) +} - // If they claim it's a directory, then make sure. - // - // N.B. we reject directories below, but we must first reject - // non-directories passed as directories. - if len(name) > 0 && name[len(name)-1] == '/' && !fs.IsDir(d.Inode.StableAttr) { - return syserror.ENOTDIR +// checkIsRegularFile prevents us from trying to execute a directory, pipe, etc. +func checkIsRegularFile(ctx context.Context, d *fs.Dirent, filename string) error { + attr := d.Inode.StableAttr + if !fs.IsRegular(attr) { + ctx.Infof("%s is not regular: %v", filename, attr) + return syserror.EACCES } + return nil +} - // No exec-ing directories, pipes, etc! - if !fs.IsRegular(d.Inode.StableAttr) { - ctx.Infof("%s is not regular: %v", name, d.Inode.StableAttr) +// checkPread checks whether we can read the file at arbitrary offsets. +func checkPread(ctx context.Context, f *fs.File, filename string) error { + if !f.Flags().Pread { + ctx.Infof("%s cannot be read at an offset: %+v", filename, f.Flags()) return syserror.EACCES } - return nil - } // allocStack allocates and maps a stack in to any available part of the address space. @@ -173,45 +224,49 @@ const ( maxLoaderAttempts = 6 ) -// loadBinary loads a binary that is pointed to by "file". If nil, the path -// "filename" is resolved and loaded. +// loadExecutable loads an executable that is pointed to by args.File. If nil, +// the path args.Filename is resolved and loaded. If the executable is an +// interpreter script rather than an ELF, the binary of the corresponding +// interpreter will be loaded. // // It returns: // * loadedELF, description of the loaded binary // * arch.Context matching the binary arch // * fs.Dirent of the binary file -// * Possibly updated argv -func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, features *cpuid.FeatureSet, filename string, passedFile *fs.File, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) { +// * Possibly updated args.Argv +func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, *fs.Dirent, []string, error) { for i := 0; i < maxLoaderAttempts; i++ { var ( d *fs.Dirent - f *fs.File err error ) - if passedFile == nil { - d, f, err = openPath(ctx, mounts, root, wd, remainingTraversals, filename) - + if args.File == nil { + d, args.File, err = openPath(ctx, args) + // We will return d in the successful case, but defer a DecRef for the + // sake of intermediate loops and failure cases. + if d != nil { + defer d.DecRef() + } + if args.File != nil { + defer args.File.DecRef() + } } else { - d, f, err = openFile(ctx, passedFile, nil, "") - // Set to nil in case we loop on a Interpreter Script. - passedFile = nil + d = args.File.Dirent + d.IncRef() + defer d.DecRef() + err = checkFile(ctx, args.File, args.Filename) } - if err != nil { - ctx.Infof("Error opening %s: %v", filename, err) + ctx.Infof("Error opening %s: %v", args.Filename, err) return loadedELF{}, nil, nil, nil, err } - defer f.DecRef() - // We will return d in the successful case, but defer a DecRef - // for intermediate loops and failure cases. - defer d.DecRef() // Check the header. Is this an ELF or interpreter script? var hdr [4]uint8 // N.B. We assume that reading from a regular file cannot block. - _, err = readFull(ctx, f, usermem.BytesIOSequence(hdr[:]), 0) - // Allow unexpected EOF, as a valid executable could be only three - // bytes (e.g., #!a). + _, err = readFull(ctx, args.File, usermem.BytesIOSequence(hdr[:]), 0) + // Allow unexpected EOF, as a valid executable could be only three bytes + // (e.g., #!a). if err != nil && err != io.ErrUnexpectedEOF { if err == io.EOF { err = syserror.ENOEXEC @@ -221,33 +276,38 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp switch { case bytes.Equal(hdr[:], []byte(elfMagic)): - loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, features, f) + loaded, ac, err := loadELF(ctx, args) if err != nil { ctx.Infof("Error loading ELF: %v", err) return loadedELF{}, nil, nil, nil, err } // An ELF is always terminal. Hold on to d. d.IncRef() - return loaded, ac, d, argv, err + return loaded, ac, d, args.Argv, err case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)): - newpath, newargv, err := parseInterpreterScript(ctx, filename, f, argv) + if args.CloseOnExec { + return loadedELF{}, nil, nil, nil, syserror.ENOENT + } + args.Filename, args.Argv, err = parseInterpreterScript(ctx, args.Filename, args.File, args.Argv) if err != nil { ctx.Infof("Error loading interpreter script: %v", err) return loadedELF{}, nil, nil, nil, err } - filename = newpath - argv = newargv + // Refresh the traversal limit for the interpreter. + *args.RemainingTraversals = linux.MaxSymlinkTraversals default: ctx.Infof("Unknown magic: %v", hdr) return loadedELF{}, nil, nil, nil, syserror.ENOEXEC } + // Set to nil in case we loop on a Interpreter Script. + args.File = nil } return loadedELF{}, nil, nil, nil, syserror.ELOOP } -// Load loads "file" into a MemoryManager. If file is nil, the path "filename" -// is resolved and loaded instead. +// Load loads args.File into a MemoryManager. If args.File is nil, the path +// args.Filename is resolved and loaded instead. // // If Load returns ErrSwitchFile it should be called again with the returned // path and argv. @@ -255,37 +315,37 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp // Preconditions: // * The Task MemoryManager is empty. // * Load is called on the Task goroutine. -func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, file *fs.File, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { - // Load the binary itself. - loaded, ac, d, argv, err := loadBinary(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv) +func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { + // Load the executable itself. + loaded, ac, d, newArgv, err := loadExecutable(ctx, args) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", filename, err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) } defer d.DecRef() // Load the VDSO. - vdsoAddr, err := loadVDSO(ctx, m, vdso, loaded) + vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) } // Setup the heap. brk starts at the next page after the end of the - // binary. Userspace can assume that the remainer of the page after + // executable. Userspace can assume that the remainer of the page after // loaded.end is available for its use. e, ok := loaded.end.RoundUp() if !ok { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC) } - m.BrkSetup(ctx, e) + args.MemoryManager.BrkSetup(ctx, e) // Allocate our stack. - stack, err := allocStack(ctx, m, ac) + stack, err := allocStack(ctx, args.MemoryManager, ac) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to allocate stack: %v", err), syserr.FromError(err).ToLinux()) } // Push the original filename to the stack, for AT_EXECFN. - execfn, err := stack.Push(filename) + execfn, err := stack.Push(args.Filename) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux()) } @@ -319,11 +379,12 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r }...) auxv = append(auxv, extraAuxv...) - sl, err := stack.Load(argv, envv, auxv) + sl, err := stack.Load(newArgv, args.Envv, auxv) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load stack: %v", err), syserr.FromError(err).ToLinux()) } + m := args.MemoryManager m.SetArgvStart(sl.ArgvStart) m.SetArgvEnd(sl.ArgvEnd) m.SetEnvvStart(sl.EnvvStart) @@ -334,7 +395,7 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r ac.SetIP(uintptr(loaded.entry)) ac.SetStack(uintptr(stack.Bottom)) - name := path.Base(filename) + name := path.Base(args.Filename) if len(name) > linux.TASK_COMM_LEN-1 { name = name[:linux.TASK_COMM_LEN-1] } diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 31fa48ec5..6803d488c 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -23,6 +23,7 @@ go_library( "machine.go", "machine_amd64.go", "machine_amd64_unsafe.go", + "machine_arm64.go", "machine_unsafe.go", "physical_map.go", "virtual_map.go", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index acd41f73d..ea8b9632e 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -127,7 +127,7 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac // not have physical mappings, the KVM module may inject // spurious exceptions when emulation fails (i.e. it tries to // emulate because the RIP is pointed at those pages). - as.machine.mapPhysical(physical, length) + as.machine.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE) // Install the page table mappings. Note that the ordering is // important; if the pagetable mappings were installed before diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/allocator.go index 80942e9c9..3f35414bb 100644 --- a/pkg/sentry/platform/kvm/allocator.go +++ b/pkg/sentry/platform/kvm/allocator.go @@ -54,7 +54,7 @@ func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr { // //go:nosplit func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs { - virtualStart, physicalStart, _, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions) if !ok { panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) } diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index b97476053..f6459cda9 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -46,9 +46,9 @@ func yield() { // calculateBluepillFault calculates the fault address range. // //go:nosplit -func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) { +func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) { alignedPhysical := physical &^ uintptr(usermem.PageSize-1) - for _, pr := range physicalRegions { + for _, pr := range phyRegions { end := pr.physical + pr.length if physical < pr.physical || physical >= end { continue @@ -77,12 +77,12 @@ func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, leng // The corresponding virtual address is returned. This may throw on error. // //go:nosplit -func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { +func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegion, flags uint32) (uintptr, bool) { // Paging fault: we need to map the underlying physical pages for this // fault. This all has to be done in this function because we're in a // signal handler context. (We can't call any functions that might // split the stack.) - virtualStart, physicalStart, length, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { return 0, false } @@ -96,7 +96,7 @@ func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { yield() // Race with another call. slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0)) } - errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart) + errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags) if errno == 0 { // Successfully added region; we can increment nextSlot and // allow another set to proceed here. diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 7e8e9f42a..3734bfb7a 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -162,7 +162,7 @@ func bluepillHandler(context unsafe.Pointer) { // For MMIO, the physical address is the first data item. physical := uintptr(c.runData.data[0]) - virtual, ok := handleBluepillFault(c.machine, physical) + virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE) if !ok { c.die(bluepillArchContext(context), "invalid physical address") return diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index d05f05c29..766131d60 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -62,3 +62,10 @@ const ( _KVM_NR_INTERRUPTS = 0x100 _KVM_NR_CPUID_ENTRIES = 0x100 ) + +// KVM kvm_memory_region::flags. +const ( + _KVM_MEM_LOG_DIRTY_PAGES = uint32(1) << 0 + _KVM_MEM_READONLY = uint32(1) << 1 + _KVM_MEM_FLAGS_NONE = 0 +) diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index cc6c138b2..7d02ebf19 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -215,6 +215,17 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) + var physicalRegionsReadOnly []physicalRegion + var physicalRegionsAvailable []physicalRegion + + physicalRegionsReadOnly = rdonlyRegionsForSetMem() + physicalRegionsAvailable = availableRegionsForSetMem() + + // Map all read-only regions. + for _, r := range physicalRegionsReadOnly { + m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY) + } + // Ensure that the currently mapped virtual regions are actually // available in the VM. Note that this doesn't guarantee no future // faults, however it should guarantee that everything is available to @@ -223,6 +234,13 @@ func newMachine(vm int) (*machine, error) { if excludeVirtualRegion(vr) { return // skip region. } + + for _, r := range physicalRegionsReadOnly { + if vr.virtual == r.virtual { + return + } + } + for virtual := vr.virtual; virtual < vr.virtual+vr.length; { physical, length, ok := translateToPhysical(virtual) if !ok { @@ -236,7 +254,7 @@ func newMachine(vm int) (*machine, error) { } // Ensure the physical range is mapped. - m.mapPhysical(physical, length) + m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE) virtual += length } }) @@ -256,9 +274,9 @@ func newMachine(vm int) (*machine, error) { // not available. This attempts to be efficient for calls in the hot path. // // This panics on error. -func (m *machine) mapPhysical(physical, length uintptr) { +func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) { for end := physical + length; physical < end; { - _, physicalStart, length, ok := calculateBluepillFault(physical) + _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { // Should never happen. panic("mapPhysical on unknown physical address") @@ -266,7 +284,7 @@ func (m *machine) mapPhysical(physical, length uintptr) { if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok { // Not present in the cache; requires setting the slot. - if _, ok := handleBluepillFault(m, physical); !ok { + if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok { panic("handleBluepillFault failed") } } diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index c1cbe33be..b99fe425e 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -355,3 +355,13 @@ func (m *machine) retryInGuest(fn func()) { } } } + +// On x86 platform, the flags for "setMemoryRegion" can always be set as 0. +// There is no need to return read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + return nil +} + +func availableRegionsForSetMem() (phyRegions []physicalRegion) { + return physicalRegions +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 506ec9af1..61227cafb 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -26,30 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/time" ) -// setMemoryRegion initializes a region. -// -// This may be called from bluepillHandler, and therefore returns an errno -// directly (instead of wrapping in an error) to avoid allocations. -// -//go:nosplit -func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno { - userRegion := userMemoryRegion{ - slot: uint32(slot), - flags: 0, - guestPhysAddr: uint64(physical), - memorySize: uint64(length), - userspaceAddr: uint64(virtual), - } - - // Set the region. - _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(m.fd), - _KVM_SET_USER_MEMORY_REGION, - uintptr(unsafe.Pointer(&userRegion))) - return errno -} - // loadSegments copies the current segments. // // This may be called from within the signal context and throws on error. diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go new file mode 100644 index 000000000..b7e2cfb9d --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -0,0 +1,61 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +// Get all read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + var rdonlyRegions []region + + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return + } + + if !vr.accessType.Write && vr.accessType.Read { + rdonlyRegions = append(rdonlyRegions, vr.region) + } + }) + + for _, r := range rdonlyRegions { + physical, _, ok := translateToPhysical(r.virtual) + if !ok { + continue + } + + phyRegions = append(phyRegions, physicalRegion{ + region: region{ + virtual: r.virtual, + length: r.length, + }, + physical: physical, + }) + } + + return phyRegions +} + +// Get all available physicalRegions. +func availableRegionsForSetMem() (phyRegions []physicalRegion) { + var excludeRegions []region + applyVirtualRegions(func(vr virtualRegion) { + if !vr.accessType.Write { + excludeRegions = append(excludeRegions, vr.region) + } + }) + + phyRegions = computePhysicalRegions(excludeRegions) + + return phyRegions +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 405e00292..ed9433311 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -35,6 +35,30 @@ func entersyscall() //go:linkname exitsyscall runtime.exitsyscall func exitsyscall() +// setMemoryRegion initializes a region. +// +// This may be called from bluepillHandler, and therefore returns an errno +// directly (instead of wrapping in an error) to avoid allocations. +// +//go:nosplit +func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno { + userRegion := userMemoryRegion{ + slot: uint32(slot), + flags: uint32(flags), + guestPhysAddr: uint64(physical), + memorySize: uint64(length), + userspaceAddr: uint64(virtual), + } + + // Set the region. + _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_SET_USER_MEMORY_REGION, + uintptr(unsafe.Pointer(&userRegion))) + return errno +} + // mapRunData maps the vCPU run data. func mapRunData(fd int) (*runData, error) { r, _, errno := syscall.RawSyscall6( diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index b699b057d..ddb1f41e3 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -335,7 +335,8 @@ func (t *thread) unexpectedStubExit() { // these cases, we don't need to panic. There is no reasons to // think that something wrong in gVisor. log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid) - syscall.Kill(os.Getpid(), syscall.SIGKILL) + pid := os.Getpid() + syscall.Tgkill(pid, pid, syscall.Signal(syscall.SIGKILL)) } t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err)) } diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index b80a3604d..2ae6b9f9d 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 4d174dda4..8b66a719d 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip/stack", "//pkg/waiter", ], ) diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 3a4fdec47..e67b46c9e 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -16,8 +16,11 @@ package hostinet import ( "fmt" + "io" "io/ioutil" "os" + "reflect" + "strconv" "strings" "syscall" @@ -26,7 +29,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) var defaultRecvBufSize = inet.TCPBufferSize{ @@ -51,6 +56,8 @@ type Stack struct { tcpRecvBufSize inet.TCPBufferSize tcpSendBufSize inet.TCPBufferSize tcpSACKEnabled bool + netDevFile *os.File + netSNMPFile *os.File } // NewStack returns an empty Stack containing no configuration. @@ -98,6 +105,18 @@ func (s *Stack) Configure() error { log.Warningf("Failed to read if TCP SACK if enabled, setting to true") } + if f, err := os.Open("/proc/net/dev"); err != nil { + log.Warningf("Failed to open /proc/net/dev: %v", err) + } else { + s.netDevFile = f + } + + if f, err := os.Open("/proc/net/snmp"); err != nil { + log.Warningf("Failed to open /proc/net/snmp: %v", err) + } else { + s.netSNMPFile = f + } + return nil } @@ -326,9 +345,95 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } +// getLine reads one line from proc file, with specified prefix. +// The last argument, withHeader, specifies if it contains line header. +func getLine(f *os.File, prefix string, withHeader bool) string { + data := make([]byte, 4096) + + if _, err := f.Seek(0, 0); err != nil { + return "" + } + + if _, err := io.ReadFull(f, data); err != io.ErrUnexpectedEOF { + return "" + } + + prefix = prefix + ":" + lines := strings.Split(string(data), "\n") + for _, l := range lines { + l = strings.TrimSpace(l) + if strings.HasPrefix(l, prefix) { + if withHeader { + withHeader = false + continue + } + return l + } + } + return "" +} + +func toSlice(i interface{}) []uint64 { + v := reflect.Indirect(reflect.ValueOf(i)) + return v.Slice(0, v.Len()).Interface().([]uint64) +} + // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserror.EOPNOTSUPP + var ( + snmpTCP bool + rawLine string + sliceStat []uint64 + ) + + switch stat.(type) { + case *inet.StatDev: + if s.netDevFile == nil { + return fmt.Errorf("/proc/net/dev is not opened for hostinet") + } + rawLine = getLine(s.netDevFile, arg, false /* with no header */) + case *inet.StatSNMPIP, *inet.StatSNMPICMP, *inet.StatSNMPICMPMSG, *inet.StatSNMPTCP, *inet.StatSNMPUDP, *inet.StatSNMPUDPLite: + if s.netSNMPFile == nil { + return fmt.Errorf("/proc/net/snmp is not opened for hostinet") + } + rawLine = getLine(s.netSNMPFile, arg, true) + default: + return syserr.ErrEndpointOperation.ToError() + } + + if rawLine == "" { + return fmt.Errorf("Failed to get raw line") + } + + parts := strings.SplitN(rawLine, ":", 2) + if len(parts) != 2 { + return fmt.Errorf("Failed to get prefix from: %q", rawLine) + } + + sliceStat = toSlice(stat) + fields := strings.Fields(strings.TrimSpace(parts[1])) + if len(fields) != len(sliceStat) { + return fmt.Errorf("Failed to parse fields: %q", rawLine) + } + if _, ok := stat.(*inet.StatSNMPTCP); ok { + snmpTCP = true + } + for i := 0; i < len(sliceStat); i++ { + var err error + if snmpTCP && i == 3 { + var tmp int64 + // MaxConn field is signed, RFC 2012. + tmp, err = strconv.ParseInt(fields[i], 10, 64) + sliceStat[i] = uint64(tmp) // Convert back to int before use. + } else { + sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64) + } + if err != nil { + return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err) + } + } + + return nil } // RouteTable implements inet.Stack.RouteTable. @@ -338,3 +443,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index debf9bce0..27c6692c4 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -149,6 +149,8 @@ var Metrics = tcpip.Stats{ TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), + CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), + EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), @@ -279,7 +281,7 @@ type SocketOperations struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { + if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil { return nil, syserr.TranslateNetstackError(err) } } @@ -1053,8 +1055,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.ErrInvalidArgument } - var v tcpip.DelayOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.DelayOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1495,11 +1497,11 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - var o tcpip.DelayOption + var o int if v == 0 { o = 1 } - return syserr.TranslateNetstackError(ep.SetSockOpt(o)) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o)) case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index fda0156e5..a0db2d4fd 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -144,7 +144,98 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserr.ErrEndpointOperation.ToError() + switch stats := stat.(type) { + case *inet.StatSNMPIP: + ip := Metrics.IP + *stats = inet.StatSNMPIP{ + 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. + 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. + ip.PacketsReceived.Value(), // InReceives. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. + ip.InvalidAddressesReceived.Value(), // InAddrErrors. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. + ip.PacketsDelivered.Value(), // InDelivers. + ip.PacketsSent.Value(), // OutRequests. + ip.OutgoingPacketErrors.Value(), // OutDiscards. + 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. + } + case *inet.StatSNMPICMP: + in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + *stats = inet.StatSNMPICMP{ + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs. + Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors. + in.DstUnreachable.Value(), // InDestUnreachs. + in.TimeExceeded.Value(), // InTimeExcds. + in.ParamProblem.Value(), // InParmProbs. + in.SrcQuench.Value(), // InSrcQuenchs. + in.Redirect.Value(), // InRedirects. + in.Echo.Value(), // InEchos. + in.EchoReply.Value(), // InEchoReps. + in.Timestamp.Value(), // InTimestamps. + in.TimestampReply.Value(), // InTimestampReps. + in.InfoRequest.Value(), // InAddrMasks. + in.InfoReply.Value(), // InAddrMaskReps. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs. + Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. + } + case *inet.StatSNMPTCP: + tcp := Metrics.TCP + // RFC 2012 (updates 1213): SNMPv2-MIB-TCP. + *stats = inet.StatSNMPTCP{ + 1, // RtoAlgorithm. + 200, // RtoMin. + 120000, // RtoMax. + (1<<64 - 1), // MaxConn. + tcp.ActiveConnectionOpenings.Value(), // ActiveOpens. + tcp.PassiveConnectionOpenings.Value(), // PassiveOpens. + tcp.FailedConnectionAttempts.Value(), // AttemptFails. + tcp.EstablishedResets.Value(), // EstabResets. + tcp.CurrentEstablished.Value(), // CurrEstab. + tcp.ValidSegmentsReceived.Value(), // InSegs. + tcp.SegmentsSent.Value(), // OutSegs. + tcp.Retransmits.Value(), // RetransSegs. + tcp.InvalidSegmentsReceived.Value(), // InErrs. + tcp.ResetsSent.Value(), // OutRsts. + tcp.ChecksumErrors.Value(), // InCsumErrors. + } + case *inet.StatSNMPUDP: + udp := Metrics.UDP + *stats = inet.StatSNMPUDP{ + udp.PacketsReceived.Value(), // InDatagrams. + udp.UnknownPortErrors.Value(), // NoPorts. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors. + udp.PacketsSent.Value(), // OutDatagrams. + udp.ReceiveBufferErrors.Value(), // RcvbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti. + } + default: + return syserr.ErrEndpointOperation.ToError() + } + return nil } // RouteTable implements inet.Stack.RouteTable. @@ -200,3 +291,18 @@ func (s *Stack) FillDefaultIPTables() { func (s *Stack) Resume() { s.Stack.Resume() } + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { + return s.Stack.RegisteredEndpoints() +} + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { + return s.Stack.CleanupEndpoints() +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { + s.Stack.RestoreCleanupEndpoints(es) +} diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 3a6baa308..4668b87d1 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -37,6 +37,7 @@ go_library( "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", "//pkg/unet", "//pkg/waiter", ], diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index 5dcb6b455..f7878a760 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/unet" ) @@ -165,3 +166,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index fb2c1777f..4c0bf96e4 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -79,6 +79,7 @@ go_library( "//pkg/sentry/kernel/signalfd", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", + "//pkg/sentry/loader", "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/safemem", diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go index aedb6d774..3021440ed 100644 --- a/pkg/sentry/syscalls/linux/linux64_amd64.go +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -362,7 +362,7 @@ var AMD64 = &kernel.SyscallTable{ 319: syscalls.Supported("memfd_create", MemfdCreate), 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil), 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), - 322: syscalls.PartiallySupported("execveat", Execveat, "No support for AT_EMPTY_PATH, AT_SYMLINK_FOLLOW.", nil), + 322: syscalls.Supported("execveat", Execveat), 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897) 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 6e425f1ec..4115116ff 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" + "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -105,38 +106,66 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user } } - if flags != 0 { - // TODO(b/128449944): Handle AT_EMPTY_PATH and AT_SYMLINK_NOFOLLOW. - t.Kernel().EmitUnimplementedEvent(t) - return 0, nil, syserror.ENOSYS + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return 0, nil, syserror.EINVAL + } + atEmptyPath := flags&linux.AT_EMPTY_PATH != 0 + if !atEmptyPath && len(pathname) == 0 { + return 0, nil, syserror.ENOENT } + resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0 root := t.FSContext().RootDirectory() defer root.DecRef() var wd *fs.Dirent + var executable *fs.File + var closeOnExec bool if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) { - // If pathname is absolute, LoadTaskImage() will ignore the wd. + // Even if the pathname is absolute, we may still need the wd + // for interpreter scripts if the path of the interpreter is + // relative. wd = t.FSContext().WorkingDirectory() } else { // Need to extract the given FD. - f := t.GetFile(dirFD) + f, fdFlags := t.FDTable().Get(dirFD) if f == nil { return 0, nil, syserror.EBADF } defer f.DecRef() + closeOnExec = fdFlags.CloseOnExec - wd = f.Dirent - wd.IncRef() - if !fs.IsDir(wd.Inode.StableAttr) { - return 0, nil, syserror.ENOTDIR + if atEmptyPath && len(pathname) == 0 { + executable = f + } else { + wd = f.Dirent + wd.IncRef() + if !fs.IsDir(wd.Inode.StableAttr) { + return 0, nil, syserror.ENOTDIR + } } } - defer wd.DecRef() + if wd != nil { + defer wd.DecRef() + } // Load the new TaskContext. - maxTraversals := uint(linux.MaxSymlinkTraversals) - tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, pathname, nil, argv, envv, t.Arch().FeatureSet()) + remainingTraversals := uint(linux.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Mounts: t.MountNamespace(), + Root: root, + WorkingDirectory: wd, + RemainingTraversals: &remainingTraversals, + ResolveFinal: resolveFinal, + Filename: pathname, + File: executable, + CloseOnExec: closeOnExec, + Argv: argv, + Envv: envv, + Features: t.Arch().FeatureSet(), + } + + tc, se := t.Kernel().LoadTaskImage(t, loadArgs) if se != nil { return 0, nil, se.ToError() } diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index b0511aa40..75e6c7dfa 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index 8f5e60a25..acbf0229b 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index 4cecfb989..b6fa6fc37 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", ], diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 02137e1c9..2f15bf1f1 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -22,6 +22,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -639,6 +640,8 @@ func ICMPv4Code(want byte) TransportChecker { // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and // potentially additional ICMPv6 header fields. +// +// ICMPv6 will validate the checksum field before calling checkers. func ICMPv6(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() @@ -650,6 +653,10 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker { } icmp := header.ICMPv6(last.Payload()) + if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { + t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) + } + for _, f := range checkers { f(t, icmp) } diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index c2bfd8c79..b4037b6c8 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -132,7 +132,7 @@ func (b ICMPv6) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) } -// SetChecksum calculates and sets the ICMP checksum field. +// SetChecksum sets the ICMP checksum field. func (b ICMPv6) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum) } @@ -197,7 +197,7 @@ func (b ICMPv6) Payload() []byte { return b[ICMPv6PayloadOffset:] } -// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header, +// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { // Calculate the IPv6 pseudo-header upper-layer checksum. diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index b28bde15b..a2b9d7435 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -15,6 +15,10 @@ package header import ( + "encoding/binary" + "errors" + "time" + "gvisor.dev/gvisor/pkg/tcpip" ) @@ -27,15 +31,187 @@ const ( // Link Layer Option for an Ethernet address. ndpTargetEthernetLinkLayerAddressSize = 8 + // NDPPrefixInformationType is the type of the Prefix Information + // option, as per RFC 4861 section 4.6.2. + NDPPrefixInformationType = 3 + + // ndpPrefixInformationLength is the expected length, in bytes, of the + // body of an NDP Prefix Information option, as per RFC 4861 section + // 4.6.2 which specifies that the Length field is 4. Given this, the + // expected length, in bytes, is 30 becuase 4 * lengthByteUnits (8) - 2 + // (Type & Length) = 30. + ndpPrefixInformationLength = 30 + + // ndpPrefixInformationPrefixLengthOffset is the offset of the Prefix + // Length field within an NDPPrefixInformation. + ndpPrefixInformationPrefixLengthOffset = 0 + + // ndpPrefixInformationFlagsOffset is the offset of the flags byte + // within an NDPPrefixInformation. + ndpPrefixInformationFlagsOffset = 1 + + // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationOnLinkFlagMask = (1 << 7) + + // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the + // Autonomous Address-Configuration flag field in the flags byte within + // an NDPPrefixInformation. + ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + + // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationReserved1FlagsMask = 63 + + // ndpPrefixInformationValidLifetimeOffset is the start of the 4-byte + // Valid Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationValidLifetimeOffset = 2 + + // ndpPrefixInformationPreferredLifetimeOffset is the start of the + // 4-byte Preferred Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationPreferredLifetimeOffset = 6 + + // ndpPrefixInformationReserved2Offset is the start of the 4-byte + // Reserved2 field within an NDPPrefixInformation. + ndpPrefixInformationReserved2Offset = 10 + + // ndpPrefixInformationReserved2Length is the length of the Reserved2 + // field. + // + // It is 4 bytes. + ndpPrefixInformationReserved2Length = 4 + + // ndpPrefixInformationPrefixOffset is the start of the Prefix field + // within an NDPPrefixInformation. + ndpPrefixInformationPrefixOffset = 14 + + // NDPPrefixInformationInfiniteLifetime is a value that represents + // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix + // Information option. Its value is (2^32 - 1)s = 4294967295s + NDPPrefixInformationInfiniteLifetime = time.Second * 4294967295 + // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. lengthByteUnits = 8 ) +// NDPOptionIterator is an iterator of NDPOption. +// +// Note, between when an NDPOptionIterator is obtained and last used, no changes +// to the NDPOptions may happen. Doing so may cause undefined and unexpected +// behaviour. It is fine to obtain an NDPOptionIterator, iterate over the first +// few NDPOption then modify the backing NDPOptions so long as the +// NDPOptionIterator obtained before modification is no longer used. +type NDPOptionIterator struct { + // The NDPOptions this NDPOptionIterator is iterating over. + opts NDPOptions +} + +// Potential errors when iterating over an NDPOptions. +var ( + ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted") + ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field") + ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") +) + +// Next returns the next element in the backing NDPOptions, or true if we are +// done, or false if an error occured. +// +// The return can be read as option, done, error. Note, option should only be +// used if done is false and error is nil. +func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { + for { + // Do we still have elements to look at? + if len(i.opts) == 0 { + return nil, true, nil + } + + // Do we have enough bytes for an NDP option that has a Length + // field of at least 1? Note, 0 in the Length field is invalid. + if len(i.opts) < lengthByteUnits { + return nil, true, ErrNDPOptBufExhausted + } + + // Get the Type field. + t := i.opts[0] + + // Get the Length field. + l := i.opts[1] + + // This would indicate an erroneous NDP option as the Length + // field should never be 0. + if l == 0 { + return nil, true, ErrNDPOptZeroLength + } + + // How many bytes are in the option body? + numBytes := int(l) * lengthByteUnits + numBodyBytes := numBytes - 2 + + potentialBody := i.opts[2:] + + // This would indicate an erroenous NDPOptions buffer as we ran + // out of the buffer in the middle of an NDP option. + if left := len(potentialBody); left < numBodyBytes { + return nil, true, ErrNDPOptBufExhausted + } + + // Get only the options body, leaving the rest of the options + // buffer alone. + body := potentialBody[:numBodyBytes] + + // Update opts with the remaining options body. + i.opts = i.opts[numBytes:] + + switch t { + case NDPTargetLinkLayerAddressOptionType: + return NDPTargetLinkLayerAddressOption(body), false, nil + + case NDPPrefixInformationType: + // Make sure the length of a Prefix Information option + // body is ndpPrefixInformationLength, as per RFC 4861 + // section 4.6.2. + if numBodyBytes != ndpPrefixInformationLength { + return nil, true, ErrNDPOptMalformedBody + } + + return NDPPrefixInformation(body), false, nil + default: + // We do not yet recognize the option, just skip for + // now. This is okay because RFC 4861 allows us to + // skip/ignore any unrecognized options. However, + // we MUST recognized all the options in RFC 4861. + // + // TODO(b/141487990): Handle all NDP options as defined + // by RFC 4861. + } + } +} + // NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6. type NDPOptions []byte +// Iter returns an iterator of NDPOption. +// +// If check is true, Iter will do an integrity check on the options by iterating +// over it and returning an error if detected. +// +// See NDPOptionIterator for more information. +func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { + it := NDPOptionIterator{opts: b} + + if check { + for it2 := it; true; { + if _, done, err := it2.Next(); err != nil || done { + return it, err + } + } + } + + return it, nil +} + // Serialize serializes the provided list of NDP options into o. // // Note, b must be of sufficient size to hold all the options in s. See @@ -75,15 +251,15 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { return done } -// ndpOption is the set of functions to be implemented by all NDP option types. -type ndpOption interface { - // Type returns the type of this ndpOption. +// NDPOption is the set of functions to be implemented by all NDP option types. +type NDPOption interface { + // Type returns the type of the receiver. Type() uint8 - // Length returns the length of the body of this ndpOption, in bytes. + // Length returns the length of the body of the receiver, in bytes. Length() int - // serializeInto serializes this ndpOption into the provided byte + // serializeInto serializes the receiver into the provided byte // buffer. // // Note, the caller MUST provide a byte buffer with size of at least @@ -92,15 +268,15 @@ type ndpOption interface { // buffer is not of sufficient size. // // serializeInto will return the number of bytes that was used to - // serialize this ndpOption. Implementers must only use the number of - // bytes required to serialize this ndpOption. Callers MAY provide a + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a // larger buffer than required to serialize into. serializeInto([]byte) int } // paddedLength returns the length of o, in bytes, with any padding bytes, if // required. -func paddedLength(o ndpOption) int { +func paddedLength(o NDPOption) int { l := o.Length() if l == 0 { @@ -139,7 +315,7 @@ func paddedLength(o ndpOption) int { } // NDPOptionsSerializer is a serializer for NDP options. -type NDPOptionsSerializer []ndpOption +type NDPOptionsSerializer []NDPOption // Length returns the total number of bytes required to serialize. func (b NDPOptionsSerializer) Length() int { @@ -154,19 +330,134 @@ func (b NDPOptionsSerializer) Length() int { // NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option // as defined by RFC 4861 section 4.6.1. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. type NDPTargetLinkLayerAddressOption tcpip.LinkAddress -// Type implements ndpOption.Type. +// Type implements NDPOption.Type. func (o NDPTargetLinkLayerAddressOption) Type() uint8 { return NDPTargetLinkLayerAddressOptionType } -// Length implements ndpOption.Length. +// Length implements NDPOption.Length. func (o NDPTargetLinkLayerAddressOption) Length() int { return len(o) } -// serializeInto implements ndpOption.serializeInto. +// serializeInto implements NDPOption.serializeInto. func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { return copy(b, o) } + +// EthernetAddress will return an ethernet (MAC) address if the +// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize +// bytes. If the body has more than EthernetAddressSize bytes, only the first +// EthernetAddressSize bytes are returned as that is all that is needed for an +// Ethernet address. +func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { + if len(o) >= EthernetAddressSize { + return tcpip.LinkAddress(o[:EthernetAddressSize]) + } + + return tcpip.LinkAddress([]byte(nil)) +} + +// NDPPrefixInformation is the NDP Prefix Information option as defined by +// RFC 4861 section 4.6.2. +// +// The length, in bytes, of a valid NDP Prefix Information option body MUST be +// ndpPrefixInformationLength bytes. +type NDPPrefixInformation []byte + +// Type implements NDPOption.Type. +func (o NDPPrefixInformation) Type() uint8 { + return NDPPrefixInformationType +} + +// Length implements NDPOption.Length. +func (o NDPPrefixInformation) Length() int { + return ndpPrefixInformationLength +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPPrefixInformation) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the Reserved1 field. + b[ndpPrefixInformationFlagsOffset] &^= ndpPrefixInformationReserved1FlagsMask + + // Zero out the Reserved2 field. + reserved2 := b[ndpPrefixInformationReserved2Offset:][:ndpPrefixInformationReserved2Length] + for i := range reserved2 { + reserved2[i] = 0 + } + + return used +} + +// PrefixLength returns the value in the number of leading bits in the Prefix +// that are valid. +// +// Valid values are in the range [0, 128], but o may not always contain valid +// values. It is up to the caller to valdiate the Prefix Information option. +func (o NDPPrefixInformation) PrefixLength() uint8 { + return o[ndpPrefixInformationPrefixLengthOffset] +} + +// OnLinkFlag returns true of the prefix is considered on-link. On-link means +// that a forwarding node is not needed to send packets to other nodes on the +// same prefix. +// +// Note, when this function returns false, no statement is made about the +// on-link property of a prefix. That is, if OnLinkFlag returns false, the +// caller MUST NOT conclude that the prefix is off-link and MUST NOT update any +// previously stored state for this prefix about its on-link status. +func (o NDPPrefixInformation) OnLinkFlag() bool { + return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationOnLinkFlagMask != 0 +} + +// AutonomousAddressConfigurationFlag returns true if the prefix can be used for +// Stateless Address Auto-Configuration (as specified in RFC 4862). +func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool { + return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationAutoAddrConfFlagMask != 0 +} + +// ValidLifetime returns the length of time that the prefix is valid for the +// purpose of on-link determination. This value is relative to the send time of +// the packet that the Prefix Information option was present in. +// +// Note, a value of 0 implies the prefix should not be considered as on-link, +// and a value of infinity/forever is represented by +// NDPPrefixInformationInfiniteLifetime. +func (o NDPPrefixInformation) ValidLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:])) +} + +// PreferredLifetime returns the length of time that an address generated from +// the prefix via Stateless Address Auto-Configuration remains preferred. This +// value is relative to the send time of the packet that the Prefix Information +// option was present in. +// +// Note, a value of 0 implies that addresses generated from the prefix should +// no longer remain preferred, and a value of infinity is represented by +// NDPPrefixInformationInfiniteLifetime. +// +// Also note that the value of this field MUST NOT exceed the Valid Lifetime +// field to avoid preferring addresses that are no longer valid, for the +// purpose of Stateless Address Auto-Configuration. +func (o NDPPrefixInformation) PreferredLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationPreferredLifetimeOffset:])) +} + +// Prefix returns an IPv6 address or a prefix of an IPv6 address. The Prefix +// Length field (see NDPPrefixInformation.PrefixLength) contains the number +// of valid leading bits in the prefix. +// +// Hosts SHOULD ignore an NDP Prefix Information option where the Prefix field +// holds the link-local prefix (fe80::). +func (o NDPPrefixInformation) Prefix() tcpip.Address { + return tcpip.Address(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize]) +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 0aac14f43..ad6daafcd 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -36,18 +36,18 @@ func TestNDPNeighborSolicit(t *testing.T) { ns := NDPNeighborSolicit(b) addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") if got := ns.TargetAddress(); got != addr { - t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr) + t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) } // Test updating the Target Address. addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") ns.SetTargetAddress(addr2) if got := ns.TargetAddress(); got != addr2 { - t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr2) + t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) } // Make sure the address got updated in the backing buffer. if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 { - t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2) + t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) } } @@ -65,56 +65,56 @@ func TestNDPNeighborAdvert(t *testing.T) { na := NDPNeighborAdvert(b) addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") if got := na.TargetAddress(); got != addr { - t.Fatalf("got TargetAddress = %s, want %s", got, addr) + t.Errorf("got TargetAddress = %s, want %s", got, addr) } // Test getting the Router Flag. if got := na.RouterFlag(); !got { - t.Fatalf("got RouterFlag = false, want = true") + t.Errorf("got RouterFlag = false, want = true") } // Test getting the Solicited Flag. if got := na.SolicitedFlag(); got { - t.Fatalf("got SolicitedFlag = true, want = false") + t.Errorf("got SolicitedFlag = true, want = false") } // Test getting the Override Flag. if got := na.OverrideFlag(); !got { - t.Fatalf("got OverrideFlag = false, want = true") + t.Errorf("got OverrideFlag = false, want = true") } // Test updating the Target Address. addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") na.SetTargetAddress(addr2) if got := na.TargetAddress(); got != addr2 { - t.Fatalf("got TargetAddress = %s, want %s", got, addr2) + t.Errorf("got TargetAddress = %s, want %s", got, addr2) } // Make sure the address got updated in the backing buffer. if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 { - t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2) + t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) } // Test updating the Router Flag. na.SetRouterFlag(false) if got := na.RouterFlag(); got { - t.Fatalf("got RouterFlag = true, want = false") + t.Errorf("got RouterFlag = true, want = false") } // Test updating the Solicited Flag. na.SetSolicitedFlag(true) if got := na.SolicitedFlag(); !got { - t.Fatalf("got SolicitedFlag = false, want = true") + t.Errorf("got SolicitedFlag = false, want = true") } // Test updating the Override Flag. na.SetOverrideFlag(false) if got := na.OverrideFlag(); got { - t.Fatalf("got OverrideFlag = true, want = false") + t.Errorf("got OverrideFlag = true, want = false") } // Make sure flags got updated in the backing buffer. if got := b[ndpNAFlagsOffset]; got != 64 { - t.Fatalf("got flags byte = %d, want = 64") + t.Errorf("got flags byte = %d, want = 64") } } @@ -128,30 +128,66 @@ func TestNDPRouterAdvert(t *testing.T) { ra := NDPRouterAdvert(b) if got := ra.CurrHopLimit(); got != 64 { - t.Fatalf("got ra.CurrHopLimit = %d, want = 64", got) + t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) } if got := ra.ManagedAddrConfFlag(); !got { - t.Fatalf("got ManagedAddrConfFlag = false, want = true") + t.Errorf("got ManagedAddrConfFlag = false, want = true") } if got := ra.OtherConfFlag(); got { - t.Fatalf("got OtherConfFlag = true, want = false") + t.Errorf("got OtherConfFlag = true, want = false") } if got, want := ra.RouterLifetime(), time.Second*258; got != want { - t.Fatalf("got ra.RouterLifetime = %d, want = %d", got, want) + t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) } if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { - t.Fatalf("got ra.ReachableTime = %d, want = %d", got, want) + t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) } if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { - t.Fatalf("got ra.RetransTimer = %d, want = %d", got, want) + t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) } } +// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the +// Ethernet address from an NDPTargetLinkLayerAddressOption. +func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { + tests := []struct { + name string + buf []byte + expected tcpip.LinkAddress + }{ + { + "ValidMAC", + []byte{1, 2, 3, 4, 5, 6}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + { + "TLLBodyTooShort", + []byte{1, 2, 3, 4, 5}, + tcpip.LinkAddress([]byte(nil)), + }, + { + "TLLBodyLargerThanNeeded", + []byte{1, 2, 3, 4, 5, 6, 7, 8}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tll := NDPTargetLinkLayerAddressOption(test.buf) + if got := tll.EthernetAddress(); got != test.expected { + t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected) + } + }) + } + +} + // TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a // NDPTargetLinkLayerAddressOption. func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { @@ -194,6 +230,341 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { if !bytes.Equal(test.buf, test.expectedBuf) { t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + if len(test.expectedBuf) > 0 { + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Fatalf("got Type %= %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + tll := next.(NDPTargetLinkLayerAddressOption) + if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { + t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + + if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { + t.Errorf("got tll.MACAddress = %s, want = %s", got, want) + } + } + + // Iterator should not return anything else. + next, done, err := it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + +// TestNDPPrefixInformationOption tests the field getters and serialization of a +// NDPPrefixInformation. +func TestNDPPrefixInformationOption(t *testing.T) { + b := []byte{ + 43, 127, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 5, 5, 5, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPPrefixInformation(b), + } + opts.Serialize(serializer) + expectedBuf := []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + if !bytes.Equal(targetBuf, expectedBuf) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + pi := next.(NDPPrefixInformation) + + if got := pi.Type(); got != 3 { + t.Errorf("got Type = %d, want = 3", got) + } + + if got := pi.Length(); got != 30 { + t.Errorf("got Length = %d, want = 30", got) + } + + if got := pi.PrefixLength(); got != 43 { + t.Errorf("got PrefixLength = %d, want = 43", got) + } + + if pi.OnLinkFlag() { + t.Error("got OnLinkFlag = true, want = false") + } + + if !pi.AutonomousAddressConfigurationFlag() { + t.Error("got AutonomousAddressConfigurationFlag = false, want = true") + } + + if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { + t.Errorf("got ValidLifetime = %d, want = %d", got, want) + } + + if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { + t.Errorf("got PreferredLifetime = %d, want = %d", got, want) + } + + if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { + t.Errorf("got Prefix = %s, want = %s", got, want) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions +// the iterator was returned for is malformed. +func TestNDPOptionsIterCheck(t *testing.T) { + tests := []struct { + name string + buf []byte + expected error + }{ + { + "ZeroLengthField", + []byte{0, 0, 0, 0, 0, 0, 0, 0}, + ErrNDPOptZeroLength, + }, + { + "ValidTargetLinkLayerAddressOption", + []byte{2, 1, 1, 2, 3, 4, 5, 6}, + nil, + }, + { + "TooSmallTargetLinkLayerAddressOption", + []byte{2, 1, 1, 2, 3, 4, 5}, + ErrNDPOptBufExhausted, + }, + { + "ValidPrefixInformation", + []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + { + "TooSmallPrefixInformation", + []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, + }, + ErrNDPOptBufExhausted, + }, + { + "InvalidPrefixInformationLength", + []byte{ + 3, 3, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + }, + ErrNDPOptMalformedBody, + }, + { + "ValidTargetLinkLayerAddressWithPrefixInformation", + []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + { + "ValidTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", + []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // 255 is an unrecognized type. If 255 ends up + // being the type for some recognized type, + // update 255 to some other unrecognized value. + 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + + if _, err := opts.Iter(true); err != test.expected { + t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected) + } + + // test.buf may be malformed but we chose not to check + // the iterator so it must return true. + if _, err := opts.Iter(false); err != nil { + t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) + } }) } } + +// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note, +// this test does not actually check any of the option's getters, it simply +// checks the option Type and Body. We have other tests that tests the option +// field gettings given an option body and don't need to duplicate those tests +// here. +func TestNDPOptionsIter(t *testing.T) { + buf := []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // 255 is an unrecognized type. If 255 ends up being the type + // for some recognized type, update 255 to some other + // unrecognized value. Note, this option should be skipped when + // iterating. + 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + + opts := NDPOptions(buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Test the first (Taret Link-Layer) option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + + // Test the next (Prefix Information) option. + // Note, the unrecognized option should be skipped. + next, done, err = it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := next.(NDPPrefixInformation), buf[26:][:30]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 59378b96c..e7c05ca4f 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents buffer.View + contents buffer.VectorisedView linkHeader buffer.View } @@ -94,7 +94,7 @@ func (c *context) cleanup() { } func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { - c.ch <- packetInfo{remote, protocol, vv.ToView(), linkHeader} + c.ch <- packetInfo{remote, protocol, vv, linkHeader} } func TestNoEthernetProperties(t *testing.T) { @@ -319,13 +319,17 @@ func TestDeliverPacket(t *testing.T) { want := packetInfo{ raddr: raddr, proto: proto, - contents: b, + contents: buffer.View(b).ToVectorisedView(), linkHeader: buffer.View(hdr), } if !eth { want.proto = header.IPv4ProtocolNumber want.raddr = "" } + // want.contents will be a single view, + // so make pi do the same for the + // DeepEqual check. + pi.contents = pi.contents.ToView().ToVectorisedView() if !reflect.DeepEqual(want, pi) { t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 12168a1dc..3331b6453 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, d.views[:used]) + vv := buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)) vv.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) @@ -293,7 +293,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - vv := buffer.NewVectorisedView(int(n), d.views[k][:used]) + vv := buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)) vv.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index dda3b10a6..0b5a6cf49 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 8d74497ba..666d8b92a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -519,6 +519,7 @@ func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } const mtu = 0xffff + const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" cases := []struct { name string expectedCount int @@ -570,7 +571,7 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, - SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + SrcAddr: outerSrcAddr, DstAddr: localIpv6Addr, }) @@ -618,6 +619,10 @@ func TestIPv6ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() + + // Set ICMPv6 checksum. + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + ep.HandlePacket(&r, vv) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6c14b4aae..c3f1dd488 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -72,6 +72,18 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V h := header.ICMPv6(v) iph := header.IPv6(netHeader) + // Validate ICMPv6 checksum before processing the packet. + // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. + // This copy is used as extra payload during the checksum calculation. + payload := vv + payload.RemoveFirst() + if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + received.Invalid.Increment() + return + } + // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field // in the IPv6 header is not set to 255. @@ -246,15 +258,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V return } - // At this point we know that the targetAddress is not tentaive + // At this point we know that the targetAddress is not tentative // on rxNICID. However, targetAddr may still be assigned to // rxNICID but not tentative (it could be permanent). Such a // scenario is beyond the scope of RFC 4862. As such, we simply // ignore such a scenario for now and proceed as normal. // - // TODO(b/140896005): Handle the scenario described above - // (inform the netstack integration that a duplicate address was - // was detected) + // TODO(b/143147598): Handle the scenario described above. Also + // inform the netstack integration that a duplicate address was + // detected outside of DAD. e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 7c11dde55..b112303b6 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -30,7 +30,7 @@ import ( ) const ( - linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") + linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") ) @@ -359,3 +359,533 @@ func TestLinkResolution(t *testing.T) { routeICMPv6Packet(t, args, nil) } } + +func TestICMPChecksumValidationSimple(t *testing.T) { + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + }, + { + "RouterSolicit", + header.ICMPv6RouterSolicit, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + "RouterAdvert", + header.ICMPv6RouterAdvert, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + "NeighborSolicit", + header.ICMPv6NeighborSolicit, + header.ICMPv6NeighborSolicitMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + "NeighborAdvert", + header.ICMPv6NeighborAdvert, + header.ICMPv6NeighborAdvertSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + "RedirectMsg", + header.ICMPv6RedirectMsg, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayload(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + icmpSize := size + payloadSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(typ) + payloadFn(pkt.Payload()) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + + payload := buffer.NewView(payloadSize) + payloadFn(payload) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size + payloadSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.Inject(ProtocolNumber, + buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, + []buffer.View{hdr.View(), payload})) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 921d1c9c7..03ddebdbd 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -51,6 +51,22 @@ const ( minimumRetransmitTimer = time.Millisecond ) +// NDPDispatcher is the interface integrators of netstack must implement to +// receive and handle NDP related events. +type NDPDispatcher interface { + // OnDuplicateAddressDetectionStatus will be called when the DAD process + // for an address (addr) on a NIC (with ID nicid) completes. resolved + // will be set to true if DAD completed successfully (no duplicate addr + // detected); false otherwise (addr was detected to be a duplicate on + // the link the NIC is a part of, or it was stopped for some other + // reason, such as the address being removed). If an error occured + // during DAD, err will be set and resolved must be ignored. + // + // This function is permitted to block indefinitely without interfering + // with the stack's operation. + OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) +} + // NDPConfigurations is the NDP configurations for the netstack. type NDPConfigurations struct { // The number of Neighbor Solicitation messages to send when doing @@ -88,6 +104,12 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { + // The NIC this ndpState is for. + nic *NIC + + // configs is the per-interface NDP configurations. + configs NDPConfigurations + // The DAD state to send the next NS message, or resolve the address. dad map[tcpip.Address]dadState } @@ -110,8 +132,8 @@ type dadState struct { // This function must only be called by IPv6 addresses that are currently // tentative. // -// The NIC that ndp belongs to (n) MUST be locked. -func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { // addr must be a valid unicast IPv6 address. if !header.IsV6UnicastAddress(addr) { return tcpip.ErrAddressFamilyNotSupported @@ -127,13 +149,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, // reference count would have been increased without doing the // work that would have been done for an address that was brand // new. See NIC.addPermanentAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, n.ID())) + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) } - remaining := n.stack.ndpConfigs.DupAddrDetectTransmits + remaining := ndp.configs.DupAddrDetectTransmits { - done, err := ndp.doDuplicateAddressDetection(n, addr, remaining, ref) + done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref) if err != nil { return err } @@ -146,42 +168,59 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, var done bool var timer *time.Timer - timer = time.AfterFunc(n.stack.ndpConfigs.RetransmitTimer, func() { - n.mu.Lock() - defer n.mu.Unlock() + timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() { + var d bool + var err *tcpip.Error + + // doDadIteration does a single iteration of the DAD loop. + // + // Returns true if the integrator needs to be informed of DAD + // completing. + doDadIteration := func() bool { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if done { + // If we reach this point, it means that the DAD + // timer fired after another goroutine already + // obtained the NIC lock and stopped DAD before + // this function obtained the NIC lock. Simply + // return here and do nothing further. + return false + } - if done { - // If we reach this point, it means that the DAD timer - // fired after another goroutine already obtained the - // NIC lock and stopped DAD before it this function - // obtained the NIC lock. Simply return here and do - // nothing further. - return - } + ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}] + if !ok { + // This should never happen. + // We should have an endpoint for addr since we + // are still performing DAD on it. If the + // endpoint does not exist, but we are doing DAD + // on it, then we started DAD at some point, but + // forgot to stop it when the endpoint was + // deleted. + panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID())) + } - ref, ok := n.endpoints[NetworkEndpointID{addr}] - if !ok { - // This should never happen. - // We should have an endpoint for addr since we are - // still performing DAD on it. If the endpoint does not - // exist, but we are doing DAD on it, then we started - // DAD at some point, but forgot to stop it when the - // endpoint was deleted. - panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, n.ID())) - } + d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref) + if err != nil || d { + delete(ndp.dad, addr) - if done, err := ndp.doDuplicateAddressDetection(n, addr, remaining, ref); err != nil || done { - if err != nil { - log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, n.ID(), err) + if err != nil { + log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) + } + + // Let the integrator know DAD has completed. + return true } - ndp.stopDuplicateAddressDetection(addr) - return + remaining-- + timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) + return false } - timer.Reset(n.stack.ndpConfigs.RetransmitTimer) - remaining-- - + if doDadIteration() && ndp.nic.stack.ndpDisp != nil { + ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err) + } }) ndp.dad[addr] = dadState{ @@ -204,11 +243,11 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, // The NIC that ndp belongs to (n) MUST be locked. // // Returns true if DAD has resolved; false if DAD is still ongoing. -func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { +func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { if ref.getKind() != permanentTentative { // The endpoint should still be marked as tentative // since we are still performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, n.ID())) + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) } if remaining == 0 { @@ -219,17 +258,17 @@ func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, rem // Send a new NS. snmc := header.SolicitedNodeAddr(addr) - snmcRef, ok := n.endpoints[NetworkEndpointID{snmc}] + snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}] if !ok { // This should never happen as if we have the // address, we should have the solicited-node // address. - panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", n.ID(), snmc, addr)) + panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) } // Use the unspecified address as the source address when performing // DAD. - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, n.linkEP.LinkAddress(), snmcRef, false, false) + r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) @@ -275,5 +314,8 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { delete(ndp.dad, addr) - return + // Let the integrator know DAD did not resolve. + if ndp.nic.stack.ndpDisp != nil { + go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 8995fbfc3..525a25218 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -31,6 +31,7 @@ import ( const ( addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" linkAddr1 = "\x02\x02\x03\x04\x05\x06" ) @@ -67,6 +68,35 @@ func TestDADDisabled(t *testing.T) { } } +// ndpDADEvent is a set of parameters that was passed to +// ndpDispatcher.OnDuplicateAddressDetectionStatus. +type ndpDADEvent struct { + nicid tcpip.NICID + addr tcpip.Address + resolved bool + err *tcpip.Error +} + +var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) + +// ndpDispatcher implements NDPDispatcher so tests can know when various NDP +// related events happen for test purposes. +type ndpDispatcher struct { + dadC chan ndpDADEvent +} + +// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. +// +// If the DAD event matches what we are expecting, send signal on n.dadC. +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { + n.dadC <- ndpDADEvent{ + nicid, + addr, + resolved, + err, + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -88,8 +118,12 @@ func TestDADResolve(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, } opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits @@ -106,8 +140,7 @@ func TestDADResolve(t *testing.T) { stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit - // Should have sent an NDP NS almost immediately. - time.Sleep(100 * time.Millisecond) + // Should have sent an NDP NS immediately. if got := stat.Value(); got != 1 { t.Fatalf("got NeighborSolicit = %d, want = 1", got) @@ -123,16 +156,10 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) } - // Wait for the remaining time - 500ms, to make sure - // the address is still not resolved. Note, we subtract - // 600ms because we already waited for 100ms earlier, - // so our remaining time is 100ms less than the expected - // time. - // (X - 100ms) - 500ms = X - 600ms - // - // TODO(b/140896005): Use events from the netstack to - // be signalled before DAD resolves. - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - 600*time.Millisecond) + // Wait for the remaining time - some delta (500ms), to + // make sure the address is still not resolved. + const delta = 500 * time.Millisecond + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) @@ -141,13 +168,30 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) } - // Wait for the remaining time + 250ms, at which point - // the address should be resolved. Note, the remaining - // time is 500ms. See above comments. - // - // TODO(b/140896005): Use events from the netstack to - // know immediately when DAD completes. - time.Sleep(750 * time.Millisecond) + // Wait for DAD to resolve. + select { + case <-time.After(2 * delta): + // We should get a resolution event after 500ms + // (delta) since we wait for 500ms less than the + // expected resolution time above to make sure + // that the address did not yet resolve. Waiting + // for 1s (2x delta) without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicid != 1 { + t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) @@ -250,9 +294,14 @@ func TestDADFail(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.DefaultNDPConfigurations(), + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, } opts.NDPConfigs.RetransmitTimer = time.Second * 2 @@ -286,8 +335,28 @@ func TestDADFail(t *testing.T) { t.Fatalf("got stat = %d, want = 1", got) } - // Wait 3 seconds to make sure that DAD did not resolve - time.Sleep(3 * time.Second) + // Wait for DAD to fail and make sure the address did + // not get resolved. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the + // expected resolution time + extra 1s buffer, + // something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicid != 1 { + t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if e.resolved { + t.Fatal("got DAD event w/ resolved = true, want = false") + } + } addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) @@ -302,11 +371,18 @@ func TestDADFail(t *testing.T) { // TestDADStop tests to make sure that the DAD process stops when an address is // removed. func TestDADStop(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.NDPConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 2, + } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, } - opts.NDPConfigs.RetransmitTimer = time.Second - opts.NDPConfigs.DupAddrDetectTransmits = 2 e := channel.New(10, 1280, linkAddr1) s := stack.New(opts) @@ -332,11 +408,27 @@ func TestDADStop(t *testing.T) { t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err) } - // Wait for the time to normally resolve - // DupAddrDetectTransmits(2) * RetransmitTimer(1s) = 2s. - // An extra 250ms is added to make sure that if DAD was still running - // it resolves and the check below fails. - time.Sleep(2*time.Second + 250*time.Millisecond) + // Wait for DAD to fail (since the address was removed during DAD). + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicid != 1 { + t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if e.resolved { + t.Fatal("got DAD event w/ resolved = true, want = false") + } + + } addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) @@ -350,3 +442,168 @@ func TestDADStop(t *testing.T) { t.Fatalf("got NeighborSolicit = %d, want <= 1", got) } } + +// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if +// we attempt to update NDP configurations using an invalid NICID. +func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + + // No NIC with ID 1 yet. + if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID { + t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID) + } +} + +// TestSetNDPConfigurations tests that we can update and use per-interface NDP +// configurations without affecting the default NDP configurations or other +// interfaces' configurations. +func TestSetNDPConfigurations(t *testing.T) { + tests := []struct { + name string + dupAddrDetectTransmits uint8 + retransmitTimer time.Duration + expectedRetransmitTimer time.Duration + }{ + { + "OK", + 1, + time.Second, + time.Second, + }, + { + "Invalid Retransmit Timer", + 1, + 0, + time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + }) + + // This NIC(1)'s NDP configurations will be updated to + // be different from the default. + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Created before updating NIC(1)'s NDP configurations + // but updating NIC(1)'s NDP configurations should not + // affect other existing NICs. + if err := s.CreateNIC(2, e); err != nil { + t.Fatalf("CreateNIC(2) = %s", err) + } + + // Update the NDP configurations on NIC(1) to use DAD. + configs := stack.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrDetectTransmits, + RetransmitTimer: test.retransmitTimer, + } + if err := s.SetNDPConfigurations(1, configs); err != nil { + t.Fatalf("got SetNDPConfigurations(1, _) = %s", err) + } + + // Created after updating NIC(1)'s NDP configurations + // but the stack's default NDP configurations should not + // have been updated. + if err := s.CreateNIC(3, e); err != nil { + t.Fatalf("CreateNIC(3) = %s", err) + } + + // Add addresses for each NIC. + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err) + } + if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err) + } + + // Address should not be considered bound to NIC(1) yet + // (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Should get the address on NIC(2) and NIC(3) + // immediately since we should not have performed DAD on + // it as the stack was configured to not do DAD by + // default and we only updated the NDP configurations on + // NIC(1). + addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err) + } + if addr.Address != addr2 { + t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2) + } + addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err) + } + if addr.Address != addr3 { + t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3) + } + + // Sleep until right (500ms before) before resolution to + // make sure the address didn't resolve on NIC(1) yet. + const delta = 500 * time.Millisecond + time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Wait for DAD to resolve. + select { + case <-time.After(2 * delta): + // We should get a resolution event after 500ms + // (delta) since we wait for 500ms less than the + // expected resolution time above to make sure + // that the address did not yet resolve. Waiting + // for 1s (2x delta) without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicid != 1 { + t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1) + } + }) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index e456e05f4..fe8f83d58 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -46,6 +46,10 @@ type NIC struct { stats NICStats + // ndp is the NDP related state for NIC. + // + // Note, read and write operations on ndp require that the NIC is + // appropriately locked. ndp ndpState } @@ -80,10 +84,16 @@ const ( NeverPrimaryEndpoint ) +// newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. + + // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints + // observe an MTU of at least 1280 bytes. Ensure that this requirement + // of IPv6 is supported on this endpoint's LinkEndpoint. + nic := &NIC{ stack: stack, id: id, @@ -105,9 +115,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback }, }, ndp: ndpState{ - dad: make(map[tcpip.Address]dadState), + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), }, } + nic.ndp.nic = nic // Register supported packet endpoint protocols. for _, netProto := range header.Ethertypes { @@ -432,7 +444,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar // If we are adding a tentative IPv6 address, start DAD. if isIPv6Unicast && kind == permanentTentative { - if err := n.ndp.startDuplicateAddressDetection(n, protocolAddress.AddressWithPrefix.Address, ref); err != nil { + if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { return nil, err } } @@ -750,7 +762,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, vv, linkHeader) + ep.HandlePacket(n.id, local, protocol, vv.Clone(nil), linkHeader) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { @@ -936,6 +948,18 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return n.removePermanentAddressLocked(addr) } +// setNDPConfigs sets the NDP configurations for n. +// +// Note, if c contains invalid NDP configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *NIC) setNDPConfigs(c NDPConfigurations) { + c.validate() + + n.mu.Lock() + n.ndp.configs = c + n.mu.Unlock() +} + type networkEndpointKind int32 const ( diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 0869fb084..d7c124e81 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -60,13 +60,34 @@ const ( // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { + // UniqueID returns an unique ID for this transport endpoint. + UniqueID() uint64 + // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. + // + // HandlePacket takes ownership of vv. HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. + // + // HandleControlPacket takes ownership of vv. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + + // Close puts the endpoint in a closed state and frees all resources + // associated with it. This cleanup may happen asynchronously. Wait can + // be used to block on this asynchronous cleanup. + Close() + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // An endpoint can be requested to stop its worker goroutines by calling + // its Close method. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // RawTransportEndpoint is the interface that needs to be implemented by raw @@ -77,6 +98,8 @@ type RawTransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. The packet contains all data from the link // layer up. + // + // HandlePacket takes ownership of packet and netHeader. HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) } @@ -93,6 +116,8 @@ type PacketEndpoint interface { // // linkHeader may have a length of 0, in which case the PacketEndpoint // should construct its own ethernet header for applications. + // + // HandlePacket takes ownership of packet and linkHeader. HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View) } @@ -143,10 +168,14 @@ type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate // transport protocol endpoint. It also returns the network layer // header for the enpoint to inspect or pass up the stack. + // + // DeliverTransportPacket takes ownership of vv and netHeader. DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. + // + // DeliverTransportControlPacket takes ownership of vv. DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) } @@ -220,6 +249,8 @@ type NetworkEndpoint interface { // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. + // + // HandlePacket takes ownership of vv. HandlePacket(r *Route, vv buffer.VectorisedView) // Close is called when the endpoint is reomved from a stack. @@ -265,6 +296,8 @@ type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint // and hands the packet over for further processing. linkHeader may have // length 0 when the caller does not have ethernet data. + // + // DeliverNetworkPacket takes ownership of vv and linkHeader. DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 284280917..115a6fcb8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "encoding/binary" "sync" + "sync/atomic" "time" "golang.org/x/time/rate" @@ -344,6 +345,13 @@ type ResumableEndpoint interface { Resume(*Stack) } +// uniqueIDGenerator is a default unique ID generator. +type uniqueIDGenerator uint64 + +func (u *uniqueIDGenerator) UniqueID() uint64 { + return atomic.AddUint64((*uint64)(u), 1) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -361,9 +369,10 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + forwarding bool + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -399,13 +408,25 @@ type Stack struct { // TODO(gvisor.dev/issue/940): S/R this field. portSeed uint32 - // ndpConfigs is the NDP configurations used by interfaces. + // ndpConfigs is the default NDP configurations used by interfaces. ndpConfigs NDPConfigurations // autoGenIPv6LinkLocal determines whether or not the stack will attempt // to auto-generate an IPv6 link-local address for newly enabled NICs. // See the AutoGenIPv6LinkLocal field of Options for more details. autoGenIPv6LinkLocal bool + + // ndpDisp is the NDP event dispatcher that is used to send the netstack + // integrator NDP related events. + ndpDisp NDPDispatcher + + // uniqueIDGenerator is a generator of unique identifiers. + uniqueIDGenerator UniqueID +} + +// UniqueID is an abstract generator of unique identifiers. +type UniqueID interface { + UniqueID() uint64 } // Options contains optional Stack configuration. @@ -429,7 +450,10 @@ type Options struct { // stack (false). HandleLocal bool - // NDPConfigs is the NDP configurations used by interfaces. + // UniqueID is an optional generator of unique identifiers. + UniqueID UniqueID + + // NDPConfigs is the default NDP configurations used by interfaces. // // By default, NDPConfigs will have a zero value for its // DupAddrDetectTransmits field, implying that DAD will not be performed @@ -448,6 +472,10 @@ type Options struct { // guidelines. AutoGenIPv6LinkLocal bool + // NDPDisp is the NDP event dispatcher that an integrator can provide to + // receive NDP related events. + NDPDisp NDPDispatcher + // RawFactory produces raw endpoints. Raw endpoints are enabled only if // this is non-nil. RawFactory RawFactory @@ -497,6 +525,10 @@ func New(opts Options) *Stack { clock = &tcpip.StdClock{} } + if opts.UniqueID == nil { + opts.UniqueID = new(uniqueIDGenerator) + } + // Make sure opts.NDPConfigs contains valid values only. opts.NDPConfigs.validate() @@ -505,6 +537,7 @@ func New(opts Options) *Stack { networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), PortManager: ports.NewPortManager(), clock: clock, @@ -514,6 +547,8 @@ func New(opts Options) *Stack { portSeed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + uniqueIDGenerator: opts.UniqueID, + ndpDisp: opts.NDPDisp, } // Add specified network protocols. @@ -540,6 +575,11 @@ func New(opts Options) *Stack { return s } +// UniqueID returns a unique identifier. +func (s *Stack) UniqueID() uint64 { + return s.uniqueIDGenerator.UniqueID() +} + // SetNetworkProtocolOption allows configuring individual protocol level // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value @@ -1127,6 +1167,25 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) } +// StartTransportEndpointCleanup removes the endpoint with the given id from +// the stack transport dispatcher. It also transitions it to the cleanup stage. +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.mu.Lock() + defer s.mu.Unlock() + + s.cleanupEndpoints[ep] = struct{}{} + + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +} + +// CompleteTransportEndpointCleanup removes the endpoint from the cleanup +// stage. +func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { + s.mu.Lock() + delete(s.cleanupEndpoints, ep) + s.mu.Unlock() +} + // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. @@ -1148,6 +1207,69 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { s.mu.Unlock() } +// RegisteredEndpoints returns all endpoints which are currently registered. +func (s *Stack) RegisteredEndpoints() []TransportEndpoint { + s.mu.Lock() + defer s.mu.Unlock() + var es []TransportEndpoint + for _, e := range s.demux.protocol { + es = append(es, e.transportEndpoints()...) + } + return es +} + +// CleanupEndpoints returns endpoints currently in the cleanup state. +func (s *Stack) CleanupEndpoints() []TransportEndpoint { + s.mu.Lock() + es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) + for e := range s.cleanupEndpoints { + es = append(es, e) + } + s.mu.Unlock() + return es +} + +// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful +// for restoring a stack after a save. +func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { + s.mu.Lock() + for _, e := range es { + s.cleanupEndpoints[e] = struct{}{} + } + s.mu.Unlock() +} + +// Close closes all currently registered transport endpoints. +// +// Endpoints created or modified during this call may not get closed. +func (s *Stack) Close() { + for _, e := range s.RegisteredEndpoints() { + e.Close() + } +} + +// Wait waits for all transport and link endpoints to halt their worker +// goroutines. +// +// Endpoints created or modified during this call may not get waited on. +// +// Note that link endpoints must be stopped via an implementation specific +// mechanism. +func (s *Stack) Wait() { + for _, e := range s.RegisteredEndpoints() { + e.Wait() + } + for _, e := range s.CleanupEndpoints() { + e.Wait() + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, n := range s.nics { + n.linkEP.Wait() + } +} + // Resume restarts the stack after a restore. This must be called after the // entire system has been restored. func (s *Stack) Resume() { @@ -1416,6 +1538,25 @@ func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tc return nic.dupTentativeAddrDetected(addr) } +// SetNDPConfigurations sets the per-interface NDP configurations on the NIC +// with ID id to c. +// +// Note, if c contains invalid NDP configuration values, it will be fixed to +// use default values for the erroneous values. +func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.setNDPConfigs(c) + + return nil +} + // PortSeed returns a 32 bit value that can be used as a seed value for port // picking. // diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9a8906a0d..9dae853d0 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1971,13 +1971,15 @@ func TestNICAutoGenAddr(t *testing.T) { // TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 // link-local addresses will only be assigned after the DAD process resolves. func TestNICAutoGenAddrDoesDAD(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 1, - }, + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, } e := channel.New(10, 1280, linkAddr1) @@ -1996,21 +1998,35 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) } - // Wait for the address to resolve (an extra - // 250ms to make sure the address resolves). - // - // TODO(b/140896005): Use events from the - // netstack to know immediately when DAD - // completes. - time.Sleep(time.Second + 250*time.Millisecond) + linkLocalAddr := header.LinkLocalAddr(linkAddr1) - // Should have auto-generated an address and - // resolved (if DAD). + // Wait for DAD to resolve. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // We should get a resolution event after 1s (default time to + // resolve as per default NDP configurations). Waiting for that + // resolution time + an extra 1s without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicid != 1 { + t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + } + if e.addr != linkLocalAddr { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) } - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(linkAddr1), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 97a1aec4b..ccd3d030e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "math/rand" + "sort" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,6 +42,31 @@ type transportEndpoints struct { rawEndpoints []RawTransportEndpoint } +// unregisterEndpoint unregisters the endpoint with the given id such that it +// won't receive any more packets. +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + eps.mu.Lock() + defer eps.mu.Unlock() + epsByNic, ok := eps.endpoints[id] + if !ok { + return + } + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return + } + delete(eps.endpoints, id) +} + +func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { + eps.mu.RLock() + defer eps.mu.RUnlock() + es := make([]TransportEndpoint, 0, len(eps.endpoints)) + for _, e := range eps.endpoints { + es = append(es, e.transportEndpoints()...) + } + return es +} + type endpointsByNic struct { mu sync.RWMutex endpoints map[tcpip.NICID]*multiPortEndpoint @@ -48,6 +74,16 @@ type endpointsByNic struct { seed uint32 } +func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + var eps []TransportEndpoint + for _, ep := range epsByNic.endpoints { + eps = append(eps, ep.transportEndpoints()...) + } + return eps +} + // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { @@ -127,21 +163,6 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T return len(epsByNic.endpoints) == 0 } -// unregisterEndpoint unregisters the endpoint with the given id such that it -// won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - eps.mu.Lock() - defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] - if !ok { - return - } - if !epsByNic.unregisterEndpoint(bindToDevice, ep) { - return - } - delete(eps.endpoints, id) -} - // transportDemuxer demultiplexes packets targeted at a transport endpoint // (i.e., after they've been parsed by the network layer). It does two levels // of demultiplexing: first based on the network and transport protocols, then @@ -183,14 +204,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. +// +// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save +// this to ensure that the underlying endpoints get saved/restored, but not not +// use the restored copy. +// +// +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. reuse bool } +func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + return eps +} + // reciprocalScale scales a value into range [0, n). // // This is similar to val % n, but faster. @@ -240,6 +274,26 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, v ep.mu.RUnlock() // Don't use defer for performance reasons. } +// Close implements stack.TransportEndpoint.Close. +func (ep *multiPortEndpoint) Close() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Close() + } +} + +// Wait implements stack.TransportEndpoint.Wait. +func (ep *multiPortEndpoint) Wait() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Wait() + } +} + // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { @@ -257,6 +311,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + + // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints + // can be restored in the same order. + sort.Slice(ep.endpointsArr, func(i, j int) bool { + return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() + }) + for i, e := range ep.endpointsArr { + ep.endpointsMap[e] = i + } return nil } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 86c62be25..203e79f56 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -43,6 +43,7 @@ type fakeTransportEndpoint struct { proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route + uniqueID uint64 // acceptQueue is non-nil iff bound. acceptQueue []fakeTransportEndpoint @@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto} +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Close() { @@ -144,6 +145,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return nil } +func (f *fakeTransportEndpoint) UniqueID() uint64 { + return f.uniqueID +} + func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -218,15 +223,15 @@ func (f *fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) { -} +func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil } -func (f *fakeTransportEndpoint) Resume(*stack.Stack) { -} +func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} + +func (f *fakeTransportEndpoint) Wait() {} type fakeTransportGoodOption bool @@ -251,7 +256,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto), nil + return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3d10001d8..03be7d3d4 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -489,6 +489,11 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption + // DelayOption is used by SetSockOpt/GetSockOpt to specify if data + // should be sent out immediately by the transport protocol. For TCP, + // it determines if the Nagle algorithm is on or off. + DelayOption + // TODO(b/137664753): convert all int socket options to be handled via // GetSockOptInt. ) @@ -501,11 +506,6 @@ type ErrorOption struct{} // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int -// DelayOption is used by SetSockOpt/GetSockOpt to specify if data should be -// sent out immediately by the transport protocol. For TCP, it determines if the -// Nagle algorithm is on or off. -type DelayOption int - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be // held until segments are full by the TCP transport protocol. type CorkOption int @@ -673,6 +673,11 @@ func (s *StatCounter) Increment() { s.IncrementBy(1) } +// Decrement minuses one to the counter. +func (s *StatCounter) Decrement() { + s.IncrementBy(^uint64(0)) +} + // Value returns the current value of the counter. func (s *StatCounter) Value() uint64 { return atomic.LoadUint64(&s.count) @@ -881,6 +886,15 @@ type TCPStats struct { // successfully via Listen. PassiveConnectionOpenings *StatCounter + // CurrentEstablished is the number of TCP connections for which the + // current state is either ESTABLISHED or CLOSE-WAIT. + CurrentEstablished *StatCounter + + // EstablishedResets is the number of times TCP connections have made + // a direct transition to the CLOSED state from either the + // ESTABLISHED state or the CLOSE-WAIT state. + EstablishedResets *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed // and a SYN was dropped. ListenOverflowSynDrop *StatCounter diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index a52262e87..48764b978 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.9 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 3187b336b..33405eb7d 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -31,9 +31,6 @@ type icmpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } type endpointState int @@ -58,6 +55,7 @@ type endpoint struct { // immutable. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -90,9 +88,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: stateInitial, + uniqueID: s.UniqueID(), }, nil } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -445,13 +449,13 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err return tcpip.ErrInvalidEndpointState } - icmpv6.SetChecksum(0) - icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) + dataVV := data.ToVectorisedView() + icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, hdr, dataVV, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { @@ -760,7 +764,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv }, } - pkt.data = vv.Clone(pkt.views[:]) + pkt.data = vv e.rcvList.PushBack(pkt) e.rcvBufSize += pkt.data.Size() @@ -798,3 +802,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 73cdaa265..ead83b83d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -41,10 +41,6 @@ type packet struct { // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -310,7 +306,7 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, if ep.cooked { // Cooked packets can simply be queued. - packet.data = vv.Clone(packet.views[:]) + packet.data = vv } else { // Raw packets need their ethernet headers prepended before // queueing. @@ -328,7 +324,7 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, } combinedVV := buffer.View(ethHeader).ToVectorisedView() combinedVV.Append(vv) - packet.data = combinedVV.Clone(packet.views[:]) + packet.data = combinedVV } packet.timestampNS = ep.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 308f10d24..23922a30e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -42,10 +42,6 @@ type rawPacket struct { // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -609,7 +605,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu combinedVV := netHeader.ToVectorisedView() combinedVV.Append(vv) - pkt.data = combinedVV.Clone(pkt.views[:]) + pkt.data = combinedVV pkt.timestampNS = e.stack.NowNanoseconds() e.rcvList.PushBack(pkt) @@ -641,3 +637,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 844959fa0..1dd00d026 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -297,6 +297,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() + ep.stack.Stats().TCP.CurrentEstablished.Increment() ep.state = StateEstablished ep.mu.Unlock() @@ -399,6 +400,9 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { + // TODO(b/143300739): Use the userMSS of the listening socket + // for accepted sockets. + switch s.flags { case header.TCPFlagSyn: opts := parseSynSegmentOptions(s) @@ -433,13 +437,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // // Enable Timestamp option if the original syn did have // the timestamp option specified. - mss := mssForRoute(&s.route) synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, TSVal: tcpTimeStamp(timeStampOffset()), TSEcr: opts.TSVal, - MSS: uint16(mss), + MSS: mssForRoute(&s.route), } e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() @@ -519,6 +522,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished // Do the delivery in a separate goroutine so diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 082135374..ca982c451 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -78,9 +78,6 @@ type handshake struct { // mss is the maximum segment size received from the peer. mss uint16 - // amss is the maximum segment size advertised by us to the peer. - amss uint16 - // sndWndScale is the send window scale, as defined in RFC 1323. A // negative value means no scaling is supported by the peer. sndWndScale int @@ -445,7 +442,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.amss = mssForRoute(&h.ep.route) + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -805,6 +802,10 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. + if e.state == StateEstablished || e.state == StateCloseWait { + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } e.state = StateError e.HardError = err if err != tcpip.ErrConnectionReset { @@ -975,6 +976,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateError e.HardError = err @@ -1005,7 +1008,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - e.state = StateEstablished + if e.state != StateEstablished { + e.stack.Stats().TCP.CurrentEstablished.Increment() + e.state = StateEstablished + } drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -1166,6 +1172,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateClose } // Lock released below. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8b9cb4c33..a1efd8d55 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -287,6 +287,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. @@ -411,7 +412,7 @@ type endpoint struct { // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. - userMSS int + userMSS uint16 // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the @@ -504,6 +505,26 @@ type endpoint struct { stats Stats `state:"nosave"` } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// calculateAdvertisedMSS calculates the MSS to advertise. +// +// If userMSS is non-zero and is not greater than the maximum possible MSS for +// r, it will be used; otherwise, the maximum possible MSS will be used. +func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { + // The maximum possible MSS is dependent on the route. + maxMSS := mssForRoute(&r) + + if userMSS != 0 && userMSS < maxMSS { + return userMSS + } + + return maxMSS +} + // StopWork halts packet processing. Only to be used in tests. func (e *endpoint) StopWork() { e.workMu.Lock() @@ -550,6 +571,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), } var ss SendBufferSizeOption @@ -671,7 +693,7 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } @@ -732,7 +754,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } @@ -742,6 +764,7 @@ func (e *endpoint) cleanupLocked() { } e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -752,7 +775,9 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } - routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + + // Use the user supplied MSS, if available. + routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 if rcvWnd > routeWnd { rcvWnd = routeWnd } @@ -1133,16 +1158,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { e.sndBufMu.Unlock() return nil - default: - return nil - } -} - -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 - const inetECNMask = 3 - switch v := opt.(type) { case tcpip.DelayOption: if v == 0 { atomic.StoreUint32(&e.delay, 0) @@ -1154,6 +1169,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch v := opt.(type) { case tcpip.CorkOption: if v == 0 { atomic.StoreUint32(&e.cork, 0) @@ -1206,7 +1231,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidOptionValue } e.mu.Lock() - e.userMSS = int(userMSS) + e.userMSS = uint16(userMSS) e.mu.Unlock() e.notifyProtocolGoroutine(notifyMSSChanged) return nil @@ -1345,6 +1370,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: e.sndBufMu.Lock() v := e.sndBufSize @@ -1357,8 +1383,16 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil + case tcpip.DelayOption: + var o int + if v := atomic.LoadUint32(&e.delay); v != 0 { + o = 1 + } + return o, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -1379,13 +1413,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.DelayOption: - *o = 0 - if v := atomic.LoadUint32(&e.delay); v != 0 { - *o = 1 - } - return nil - case *tcpip.CorkOption: *o = 0 if v := atomic.LoadUint32(&e.cork); v != 0 { @@ -1729,6 +1756,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) e.state = StateEstablished + e.stack.Stats().TCP.CurrentEstablished.Increment() } if run { @@ -2379,6 +2407,23 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.mu.Lock() + running := e.workerRunning + e.mu.Unlock() + if !running { + break + } + <-notifyCh + } +} + func mssForRoute(r *stack.Route) uint16 { + // TODO(b/143359391): Respect TCP Min and Max size. return uint16(r.MTU() - header.TCPMinimumSize) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index eae17237e..19f003b6b 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -193,8 +193,10 @@ func (e *endpoint) Resume(s *stack.Stack) { if len(e.BindAddr) == 0 { e.BindAddr = e.ID.LocalAddress } - if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) + addr := e.BindAddr + port := e.ID.LocalPort + if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { + panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) } } @@ -265,6 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) { } fallthrough case StateError: + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8332a0179..d3f7c9125 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -674,6 +674,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.state = StateFinWait1 } + s.ep.stack.Stats().TCP.CurrentEstablished.Decrement() s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 6d022a266..126f26ed3 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -474,6 +474,130 @@ func TestSimpleReceive(t *testing.T) { ) } +// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when +// creating a new active IPv4 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV4(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.Create(-1) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv4 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when +// creating a new active IPv6 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV6(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv6 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -1623,7 +1747,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -1671,7 +1795,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -1704,7 +1828,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOpt(tcpip.DelayOption(0)) + c.EP.SetSockOptInt(tcpip.DelayOption, 0) // Check that data is received. second := c.GetPacket() @@ -1741,7 +1865,7 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.DelayOption(1)) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 91c8487f3..03bd5c8fd 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -31,9 +31,6 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } // EndpointState represents the state of a UDP endpoint. @@ -80,6 +77,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -160,9 +158,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: StateInitial, + uniqueID: s.UniqueID(), } } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -1195,7 +1199,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv Port: hdr.SourcePort(), }, } - pkt.data = vv.Clone(pkt.views[:]) + pkt.data = vv e.rcvList.PushBack(pkt) e.rcvBufSize += vv.Size() @@ -1234,6 +1238,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements tcpip.Endpoint.Wait. +func (*endpoint) Wait() {} + func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 0c0eba99e..86df384f8 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -232,7 +232,7 @@ func New(args Args) (*Loader, error) { // this point. Netns is configured before Run() is called. Netstack is // configured using a control uRPC message. Host network is configured inside // Run(). - networkStack, err := newEmptyNetworkStack(args.Conf, k) + networkStack, err := newEmptyNetworkStack(args.Conf, k, k) if err != nil { return nil, fmt.Errorf("creating network: %v", err) } @@ -905,7 +905,7 @@ func (l *Loader) WaitExit() kernel.ExitStatus { return l.k.GlobalInit().ExitStatus() } -func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { +func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { switch conf.Network { case NetworkHost: return hostinet.NewStack(), nil @@ -923,6 +923,7 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { // Enable raw sockets for users with sufficient // privileges. RawFactory: raw.EndpointFactory{}, + UniqueID: uniqueID, })} // Enable SACK Recovery. diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 26d1cd5ab..2bd12120d 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "container.go", "hook.go", + "state_file.go", "status.go", ], importpath = "gvisor.dev/gvisor/runsc/container", diff --git a/runsc/container/container.go b/runsc/container/container.go index 32510d427..68782c4be 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -17,13 +17,11 @@ package container import ( "context" - "encoding/json" "fmt" "io/ioutil" "os" "os/exec" "os/signal" - "path/filepath" "regexp" "strconv" "strings" @@ -31,7 +29,6 @@ import ( "time" "github.com/cenkalti/backoff" - "github.com/gofrs/flock" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" @@ -41,17 +38,6 @@ import ( "gvisor.dev/gvisor/runsc/specutils" ) -const ( - // metadataFilename is the name of the metadata file relative to the - // container root directory that holds sandbox metadata. - metadataFilename = "meta.json" - - // metadataLockFilename is the name of a lock file in the container - // root directory that is used to prevent concurrent modifications to - // the container state and metadata. - metadataLockFilename = "meta.lock" -) - // validateID validates the container id. func validateID(id string) error { // See libcontainer/factory_linux.go. @@ -99,11 +85,6 @@ type Container struct { // BundleDir is the directory containing the container bundle. BundleDir string `json:"bundleDir"` - // Root is the directory containing the container metadata file. If this - // container is the root container, Root and RootContainerDir will be the - // same. - Root string `json:"root"` - // CreatedAt is the time the container was created. CreatedAt time.Time `json:"createdAt"` @@ -121,21 +102,24 @@ type Container struct { // be 0 if the gofer has been killed. GoferPid int `json:"goferPid"` + // Sandbox is the sandbox this container is running in. It's set when the + // container is created and reset when the sandbox is destroyed. + Sandbox *sandbox.Sandbox `json:"sandbox"` + + // Saver handles load from/save to the state file safely from multiple + // processes. + Saver StateFile `json:"saver"` + + // + // Fields below this line are not saved in the state file and will not + // be preserved across commands. + // + // goferIsChild is set if a gofer process is a child of the current process. // // This field isn't saved to json, because only a creator of a gofer // process will have it as a child process. goferIsChild bool - - // Sandbox is the sandbox this container is running in. It's set when the - // container is created and reset when the sandbox is destroyed. - Sandbox *sandbox.Sandbox `json:"sandbox"` - - // RootContainerDir is the root directory containing the metadata file of the - // sandbox root container. It's used to lock in order to serialize creating - // and deleting this Container's metadata directory. If this container is the - // root container, this is the same as Root. - RootContainerDir string } // loadSandbox loads all containers that belong to the sandbox with the given @@ -166,43 +150,35 @@ func loadSandbox(rootDir, id string) ([]*Container, error) { return containers, nil } -// Load loads a container with the given id from a metadata file. id may be an -// abbreviation of the full container id, in which case Load loads the -// container to which id unambiguously refers to. -// Returns ErrNotExist if container doesn't exist. -func Load(rootDir, id string) (*Container, error) { - log.Debugf("Load container %q %q", rootDir, id) - if err := validateID(id); err != nil { +// Load loads a container with the given id from a metadata file. partialID may +// be an abbreviation of the full container id, in which case Load loads the +// container to which id unambiguously refers to. Returns ErrNotExist if +// container doesn't exist. +func Load(rootDir, partialID string) (*Container, error) { + log.Debugf("Load container %q %q", rootDir, partialID) + if err := validateID(partialID); err != nil { return nil, fmt.Errorf("validating id: %v", err) } - cRoot, err := findContainerRoot(rootDir, id) + id, err := findContainerID(rootDir, partialID) if err != nil { // Preserve error so that callers can distinguish 'not found' errors. return nil, err } - // Lock the container metadata to prevent other runsc instances from - // writing to it while we are reading it. - unlock, err := lockContainerMetadata(cRoot) - if err != nil { - return nil, err + state := StateFile{ + RootDir: rootDir, + ID: id, } - defer unlock() + defer state.close() - // Read the container metadata file and create a new Container from it. - metaFile := filepath.Join(cRoot, metadataFilename) - metaBytes, err := ioutil.ReadFile(metaFile) - if err != nil { + c := &Container{} + if err := state.load(c); err != nil { if os.IsNotExist(err) { // Preserve error so that callers can distinguish 'not found' errors. return nil, err } - return nil, fmt.Errorf("reading container metadata file %q: %v", metaFile, err) - } - var c Container - if err := json.Unmarshal(metaBytes, &c); err != nil { - return nil, fmt.Errorf("unmarshaling container metadata from %q: %v", metaFile, err) + return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err) } // If the status is "Running" or "Created", check that the sandbox @@ -223,57 +199,37 @@ func Load(rootDir, id string) (*Container, error) { } } - return &c, nil + return c, nil } -func findContainerRoot(rootDir, partialID string) (string, error) { +func findContainerID(rootDir, partialID string) (string, error) { // Check whether the id fully specifies an existing container. - cRoot := filepath.Join(rootDir, partialID) - if _, err := os.Stat(cRoot); err == nil { - return cRoot, nil + stateFile := buildStatePath(rootDir, partialID) + if _, err := os.Stat(stateFile); err == nil { + return partialID, nil } // Now see whether id could be an abbreviation of exactly 1 of the // container ids. If id is ambiguous (it could match more than 1 // container), it is an error. - cRoot = "" ids, err := List(rootDir) if err != nil { return "", err } + rv := "" for _, id := range ids { if strings.HasPrefix(id, partialID) { - if cRoot != "" { - return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, cRoot, id) + if rv != "" { + return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, rv, id) } - cRoot = id + rv = id } } - if cRoot == "" { + if rv == "" { return "", os.ErrNotExist } - log.Debugf("abbreviated id %q resolves to full id %q", partialID, cRoot) - return filepath.Join(rootDir, cRoot), nil -} - -// List returns all container ids in the given root directory. -func List(rootDir string) ([]string, error) { - log.Debugf("List containers %q", rootDir) - fs, err := ioutil.ReadDir(rootDir) - if err != nil { - return nil, fmt.Errorf("reading dir %q: %v", rootDir, err) - } - var out []string - for _, f := range fs { - // Filter out directories that do no belong to a container. - cid := f.Name() - if validateID(cid) == nil { - if _, err := os.Stat(filepath.Join(rootDir, cid, metadataFilename)); err == nil { - out = append(out, f.Name()) - } - } - } - return out, nil + log.Debugf("abbreviated id %q resolves to full id %q", partialID, rv) + return rv, nil } // Args is used to configure a new container. @@ -316,44 +272,34 @@ func New(conf *boot.Config, args Args) (*Container, error) { return nil, err } - unlockRoot, err := maybeLockRootContainer(args.Spec, conf.RootDir) - if err != nil { - return nil, err + if err := os.MkdirAll(conf.RootDir, 0711); err != nil { + return nil, fmt.Errorf("creating container root directory: %v", err) } - defer unlockRoot() + + c := &Container{ + ID: args.ID, + Spec: args.Spec, + ConsoleSocket: args.ConsoleSocket, + BundleDir: args.BundleDir, + Status: Creating, + CreatedAt: time.Now(), + Owner: os.Getenv("USER"), + Saver: StateFile{ + RootDir: conf.RootDir, + ID: args.ID, + }, + } + // The Cleanup object cleans up partially created containers when an error + // occurs. Any errors occurring during cleanup itself are ignored. + cu := specutils.MakeCleanup(func() { _ = c.Destroy() }) + defer cu.Clean() // Lock the container metadata file to prevent concurrent creations of // containers with the same id. - containerRoot := filepath.Join(conf.RootDir, args.ID) - unlock, err := lockContainerMetadata(containerRoot) - if err != nil { + if err := c.Saver.lockForNew(); err != nil { return nil, err } - defer unlock() - - // Check if the container already exists by looking for the metadata - // file. - if _, err := os.Stat(filepath.Join(containerRoot, metadataFilename)); err == nil { - return nil, fmt.Errorf("container with id %q already exists", args.ID) - } else if !os.IsNotExist(err) { - return nil, fmt.Errorf("looking for existing container in %q: %v", containerRoot, err) - } - - c := &Container{ - ID: args.ID, - Spec: args.Spec, - ConsoleSocket: args.ConsoleSocket, - BundleDir: args.BundleDir, - Root: containerRoot, - Status: Creating, - CreatedAt: time.Now(), - Owner: os.Getenv("USER"), - RootContainerDir: conf.RootDir, - } - // The Cleanup object cleans up partially created containers when an error occurs. - // Any errors occuring during cleanup itself are ignored. - cu := specutils.MakeCleanup(func() { _ = c.Destroy() }) - defer cu.Clean() + defer c.Saver.unlock() // If the metadata annotations indicate that this container should be // started in an existing sandbox, we must do so. The metadata will @@ -431,7 +377,7 @@ func New(conf *boot.Config, args Args) (*Container, error) { c.changeStatus(Created) // Save the metadata file. - if err := c.save(); err != nil { + if err := c.saveLocked(); err != nil { return nil, err } @@ -451,17 +397,12 @@ func New(conf *boot.Config, args Args) (*Container, error) { func (c *Container) Start(conf *boot.Config) error { log.Debugf("Start container %q", c.ID) - unlockRoot, err := maybeLockRootContainer(c.Spec, c.RootContainerDir) - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlockRoot() + unlock := specutils.MakeCleanup(func() { c.Saver.unlock() }) + defer unlock.Clean() - unlock, err := c.lock() - if err != nil { - return err - } - defer unlock() if err := c.requireStatus("start", Created); err != nil { return err } @@ -509,14 +450,15 @@ func (c *Container) Start(conf *boot.Config) error { } c.changeStatus(Running) - if err := c.save(); err != nil { + if err := c.saveLocked(); err != nil { return err } - // Adjust the oom_score_adj for sandbox. This must be done after - // save(). - err = adjustSandboxOOMScoreAdj(c.Sandbox, c.RootContainerDir, false) - if err != nil { + // Release lock before adjusting OOM score because the lock is acquired there. + unlock.Clean() + + // Adjust the oom_score_adj for sandbox. This must be done after saveLocked(). + if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Saver.RootDir, false); err != nil { return err } @@ -529,11 +471,10 @@ func (c *Container) Start(conf *boot.Config) error { // to restore a container from its state file. func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile string) error { log.Debugf("Restore container %q", c.ID) - unlock, err := c.lock() - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlock() + defer c.Saver.unlock() if err := c.requireStatus("restore", Created); err != nil { return err @@ -551,7 +492,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile str return err } c.changeStatus(Running) - return c.save() + return c.saveLocked() } // Run is a helper that calls Create + Start + Wait. @@ -711,11 +652,10 @@ func (c *Container) Checkpoint(f *os.File) error { // The call only succeeds if the container's status is created or running. func (c *Container) Pause() error { log.Debugf("Pausing container %q", c.ID) - unlock, err := c.lock() - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlock() + defer c.Saver.unlock() if c.Status != Created && c.Status != Running { return fmt.Errorf("cannot pause container %q in state %v", c.ID, c.Status) @@ -725,18 +665,17 @@ func (c *Container) Pause() error { return fmt.Errorf("pausing container: %v", err) } c.changeStatus(Paused) - return c.save() + return c.saveLocked() } // Resume unpauses the container and its kernel. // The call only succeeds if the container's status is paused. func (c *Container) Resume() error { log.Debugf("Resuming container %q", c.ID) - unlock, err := c.lock() - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlock() + defer c.Saver.unlock() if c.Status != Paused { return fmt.Errorf("cannot resume container %q in state %v", c.ID, c.Status) @@ -745,7 +684,7 @@ func (c *Container) Resume() error { return fmt.Errorf("resuming container: %v", err) } c.changeStatus(Running) - return c.save() + return c.saveLocked() } // State returns the metadata of the container. @@ -773,6 +712,17 @@ func (c *Container) Processes() ([]*control.Process, error) { func (c *Container) Destroy() error { log.Debugf("Destroy container %q", c.ID) + if err := c.Saver.lock(); err != nil { + return err + } + defer func() { + c.Saver.unlock() + c.Saver.close() + }() + + // Stored for later use as stop() sets c.Sandbox to nil. + sb := c.Sandbox + // We must perform the following cleanup steps: // * stop the container and gofer processes, // * remove the container filesystem on the host, and @@ -782,48 +732,43 @@ func (c *Container) Destroy() error { // do our best to perform all of the cleanups. Hence, we keep a slice // of errors return their concatenation. var errs []string - - unlock, err := maybeLockRootContainer(c.Spec, c.RootContainerDir) - if err != nil { - return err - } - defer unlock() - - // Stored for later use as stop() sets c.Sandbox to nil. - sb := c.Sandbox - if err := c.stop(); err != nil { err = fmt.Errorf("stopping container: %v", err) log.Warningf("%v", err) errs = append(errs, err.Error()) } - if err := os.RemoveAll(c.Root); err != nil && !os.IsNotExist(err) { - err = fmt.Errorf("deleting container root directory %q: %v", c.Root, err) + if err := c.Saver.destroy(); err != nil { + err = fmt.Errorf("deleting container state files: %v", err) log.Warningf("%v", err) errs = append(errs, err.Error()) } c.changeStatus(Stopped) - // Adjust oom_score_adj for the sandbox. This must be done after the - // container is stopped and the directory at c.Root is removed. - // We must test if the sandbox is nil because Destroy should be - // idempotent. - if sb != nil { - if err := adjustSandboxOOMScoreAdj(sb, c.RootContainerDir, true); err != nil { + // Adjust oom_score_adj for the sandbox. This must be done after the container + // is stopped and the directory at c.Root is removed. Adjustment can be + // skipped if the root container is exiting, because it brings down the entire + // sandbox. + // + // Use 'sb' to tell whether it has been executed before because Destroy must + // be idempotent. + if sb != nil && !isRoot(c.Spec) { + if err := adjustSandboxOOMScoreAdj(sb, c.Saver.RootDir, true); err != nil { errs = append(errs, err.Error()) } } // "If any poststop hook fails, the runtime MUST log a warning, but the - // remaining hooks and lifecycle continue as if the hook had succeeded" -OCI spec. - // Based on the OCI, "The post-stop hooks MUST be called after the container is - // deleted but before the delete operation returns" + // remaining hooks and lifecycle continue as if the hook had + // succeeded" - OCI spec. + // + // Based on the OCI, "The post-stop hooks MUST be called after the container + // is deleted but before the delete operation returns" // Run it here to: // 1) Conform to the OCI. - // 2) Make sure it only runs once, because the root has been deleted, the container - // can't be loaded again. + // 2) Make sure it only runs once, because the root has been deleted, the + // container can't be loaded again. if c.Spec.Hooks != nil { executeHooksBestEffort(c.Spec.Hooks.Poststop, c.State()) } @@ -834,18 +779,13 @@ func (c *Container) Destroy() error { return fmt.Errorf(strings.Join(errs, "\n")) } -// save saves the container metadata to a file. +// saveLocked saves the container metadata to a file. // // Precondition: container must be locked with container.lock(). -func (c *Container) save() error { +func (c *Container) saveLocked() error { log.Debugf("Save container %q", c.ID) - metaFile := filepath.Join(c.Root, metadataFilename) - meta, err := json.Marshal(c) - if err != nil { - return fmt.Errorf("invalid container metadata: %v", err) - } - if err := ioutil.WriteFile(metaFile, meta, 0640); err != nil { - return fmt.Errorf("writing container metadata: %v", err) + if err := c.Saver.saveLocked(c); err != nil { + return fmt.Errorf("saving container metadata: %v", err) } return nil } @@ -1106,48 +1046,6 @@ func (c *Container) requireStatus(action string, statuses ...Status) error { return fmt.Errorf("cannot %s container %q in state %s", action, c.ID, c.Status) } -// lock takes a file lock on the container metadata lock file. -func (c *Container) lock() (func() error, error) { - return lockContainerMetadata(filepath.Join(c.Root, c.ID)) -} - -// lockContainerMetadata takes a file lock on the metadata lock file in the -// given container root directory. -func lockContainerMetadata(containerRootDir string) (func() error, error) { - if err := os.MkdirAll(containerRootDir, 0711); err != nil { - return nil, fmt.Errorf("creating container root directory %q: %v", containerRootDir, err) - } - f := filepath.Join(containerRootDir, metadataLockFilename) - l := flock.NewFlock(f) - if err := l.Lock(); err != nil { - return nil, fmt.Errorf("acquiring lock on container lock file %q: %v", f, err) - } - return l.Unlock, nil -} - -// maybeLockRootContainer locks the sandbox root container. It is used to -// prevent races to create and delete child container sandboxes. -func maybeLockRootContainer(spec *specs.Spec, rootDir string) (func() error, error) { - if isRoot(spec) { - return func() error { return nil }, nil - } - - sbid, ok := specutils.SandboxID(spec) - if !ok { - return nil, fmt.Errorf("no sandbox ID found when locking root container") - } - sb, err := Load(rootDir, sbid) - if err != nil { - return nil, err - } - - unlock, err := sb.lock() - if err != nil { - return nil, err - } - return unlock, nil -} - func isRoot(spec *specs.Spec) bool { return specutils.SpecContainerType(spec) != specutils.ContainerTypeContainer } @@ -1170,7 +1068,12 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error { func (c *Container) adjustGoferOOMScoreAdj() error { if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil { if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil { - return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err) + // Ignore NotExist error because it can be returned when the sandbox + // exited while OOM score was being adjusted. + if !os.IsNotExist(err) { + return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err) + } + log.Warningf("Gofer process (%d) not found setting oom_score_adj", c.GoferPid) } } @@ -1252,7 +1155,12 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) // Set the lowest of all containers oom_score_adj to the sandbox. if err := setOOMScoreAdj(s.Pid, lowScore); err != nil { - return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err) + // Ignore NotExist error because it can be returned when the sandbox + // exited while OOM score was being adjusted. + if !os.IsNotExist(err) { + return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err) + } + log.Warningf("Sandbox process (%d) not found setting oom_score_adj", s.Pid) } return nil diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index c4c56b2e0..07eacaac0 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -1548,7 +1548,8 @@ func TestAbbreviatedIDs(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfigWithRoot(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir cids := []string{ "foo-" + testutil.UniqueContainerID(), diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 9e02a825e..a5a62378c 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -60,13 +60,8 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { } func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*Container, func(), error) { - // Setup root dir if one hasn't been provided. if len(conf.RootDir) == 0 { - rootDir, err := testutil.SetupRootDir() - if err != nil { - return nil, nil, fmt.Errorf("error creating root dir: %v", err) - } - conf.RootDir = rootDir + panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") } var containers []*Container @@ -78,7 +73,6 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*C for _, b := range bundles { os.RemoveAll(b) } - os.RemoveAll(conf.RootDir) } for i, spec := range specs { bundleDir, err := testutil.SetupBundleDir(spec) @@ -144,6 +138,13 @@ func TestMultiContainerSanity(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep) @@ -175,6 +176,13 @@ func TestMultiPIDNS(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} testSpecs, ids := createSpecs(sleep, sleep) @@ -213,6 +221,13 @@ func TestMultiPIDNSPath(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} testSpecs, ids := createSpecs(sleep, sleep, sleep) @@ -268,13 +283,21 @@ func TestMultiPIDNSPath(t *testing.T) { } func TestMultiContainerWait(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // The first container should run the entire duration of the test. cmd1 := []string{"sleep", "100"} // We'll wait on the second container, which is much shorter lived. cmd2 := []string{"sleep", "1"} specs, ids := createSpecs(cmd1, cmd2) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -344,12 +367,14 @@ func TestExecWait(t *testing.T) { } defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir + // The first container should run the entire duration of the test. cmd1 := []string{"sleep", "100"} // We'll wait on the second container, which is much shorter lived. cmd2 := []string{"sleep", "1"} specs, ids := createSpecs(cmd1, cmd2) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -432,7 +457,15 @@ func TestMultiContainerMount(t *testing.T) { }) // Setup the containers. + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir + containers, cleanup, err := startContainers(conf, sps, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -454,6 +487,13 @@ func TestMultiContainerSignal(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep) @@ -548,6 +588,13 @@ func TestMultiContainerDestroy(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // First container will remain intact while the second container is killed. podSpecs, ids := createSpecs( []string{"sleep", "100"}, @@ -599,13 +646,21 @@ func TestMultiContainerDestroy(t *testing.T) { } func TestMultiContainerProcesses(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // Note: use curly braces to keep 'sh' process around. Otherwise, shell // will just execve into 'sleep' and both containers will look the // same. specs, ids := createSpecs( []string{"sleep", "100"}, []string{"sh", "-c", "{ sleep 100; }"}) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -650,6 +705,15 @@ func TestMultiContainerProcesses(t *testing.T) { // TestMultiContainerKillAll checks that all process that belong to a container // are killed when SIGKILL is sent to *all* processes in that container. func TestMultiContainerKillAll(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + for _, tc := range []struct { killContainer bool }{ @@ -665,7 +729,6 @@ func TestMultiContainerKillAll(t *testing.T) { specs, ids := createSpecs( []string{app, "task-tree", "--depth=2", "--width=2"}, []string{app, "task-tree", "--depth=4", "--width=2"}) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -739,19 +802,13 @@ func TestMultiContainerDestroyNotStarted(t *testing.T) { specs, ids := createSpecs( []string{"/bin/sleep", "100"}, []string{"/bin/sleep", "100"}) - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfigWithRoot(rootDir) - // Create and start root container. - rootBundleDir, err := testutil.SetupBundleDir(specs[0]) + conf := testutil.TestConfig() + rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf) if err != nil { t.Fatalf("error setting up container: %v", err) } + defer os.RemoveAll(rootDir) defer os.RemoveAll(rootBundleDir) rootArgs := Args{ @@ -800,19 +857,12 @@ func TestMultiContainerDestroyStarting(t *testing.T) { } specs, ids := createSpecs(cmds...) - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfigWithRoot(rootDir) - - // Create and start root container. - rootBundleDir, err := testutil.SetupBundleDir(specs[0]) + conf := testutil.TestConfig() + rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf) if err != nil { t.Fatalf("error setting up container: %v", err) } + defer os.RemoveAll(rootDir) defer os.RemoveAll(rootBundleDir) rootArgs := Args{ @@ -886,9 +936,17 @@ func TestMultiContainerDifferentFilesystems(t *testing.T) { script := fmt.Sprintf("if [ -f %q ]; then exit 1; else touch %q; fi", filename, filename) cmd := []string{"sh", "-c", script} + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // Make sure overlay is enabled, and none of the root filesystems are // read-only, otherwise we won't be able to create the file. - conf := testutil.TestConfig() conf.Overlay = true specs, ids := createSpecs(cmdRoot, cmd, cmd) for _, s := range specs { @@ -941,26 +999,21 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) { } allSpecs, allIDs := createSpecs(cmds...) - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - // Split up the specs and IDs. rootSpec := allSpecs[0] rootID := allIDs[0] childrenSpecs := allSpecs[1:] childrenIDs := allIDs[1:] - bundleDir, err := testutil.SetupBundleDir(rootSpec) + conf := testutil.TestConfig() + rootDir, bundleDir, err := testutil.SetupContainer(rootSpec, conf) if err != nil { - t.Fatalf("error setting up bundle dir: %v", err) + t.Fatalf("error setting up container: %v", err) } + defer os.RemoveAll(rootDir) defer os.RemoveAll(bundleDir) // Start root container. - conf := testutil.TestConfigWithRoot(rootDir) rootArgs := Args{ ID: rootID, Spec: rootSpec, @@ -1029,6 +1082,13 @@ func TestMultiContainerSharedMount(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} podSpec, ids := createSpecs(sleep, sleep) @@ -1137,6 +1197,13 @@ func TestMultiContainerSharedMountReadonly(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} podSpec, ids := createSpecs(sleep, sleep) @@ -1197,6 +1264,13 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} podSpec, ids := createSpecs(sleep, sleep) @@ -1300,8 +1374,14 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { // Test that unsupported pod mounts options are ignored when matching master and // slave mounts. func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() - t.Logf("Running test with conf: %+v", conf) + conf.RootDir = rootDir // Setup the containers. sleep := []string{"/bin/sleep", "100"} @@ -1376,6 +1456,15 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { Type: "tmpfs", } + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // Create the specs. specs, ids := createSpecs( []string{"sleep", "1000"}, @@ -1386,7 +1475,6 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { specs[1].Mounts = append(specs[2].Mounts, sharedMnt, writeableMnt) specs[2].Mounts = append(specs[1].Mounts, sharedMnt) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -1405,9 +1493,17 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { // Test that container is destroyed when Gofer is killed. func TestMultiContainerGoferKilled(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep, sleep) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -1483,7 +1579,15 @@ func TestMultiContainerGoferKilled(t *testing.T) { func TestMultiContainerLoadSandbox(t *testing.T) { sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep, sleep) + + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir // Create containers for the sandbox. wants, cleanup, err := startContainers(conf, specs, ids) @@ -1576,7 +1680,15 @@ func TestMultiContainerRunNonRoot(t *testing.T) { Type: "bind", }) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir + pod, cleanup, err := startContainers(conf, podSpecs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go new file mode 100644 index 000000000..d95151ea5 --- /dev/null +++ b/runsc/container/state_file.go @@ -0,0 +1,185 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package container + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "sync" + + "github.com/gofrs/flock" + "gvisor.dev/gvisor/pkg/log" +) + +const stateFileExtension = ".state" + +// StateFile handles load from/save to container state safely from multiple +// processes. It uses a lock file to provide synchronization between operations. +// +// The lock file is located at: "${s.RootDir}/${s.ID}.lock". +// The state file is located at: "${s.RootDir}/${s.ID}.state". +type StateFile struct { + // RootDir is the directory containing the container metadata file. + RootDir string `json:"rootDir"` + + // ID is the container ID. + ID string `json:"id"` + + // + // Fields below this line are not saved in the state file and will not + // be preserved across commands. + // + + once sync.Once + flock *flock.Flock +} + +// List returns all container ids in the given root directory. +func List(rootDir string) ([]string, error) { + log.Debugf("List containers %q", rootDir) + list, err := filepath.Glob(filepath.Join(rootDir, "*"+stateFileExtension)) + if err != nil { + return nil, err + } + var out []string + for _, path := range list { + // Filter out files that do no belong to a container. + fileName := filepath.Base(path) + if len(fileName) < len(stateFileExtension) { + panic(fmt.Sprintf("invalid file match %q", path)) + } + // Remove the extension. + cid := fileName[:len(fileName)-len(stateFileExtension)] + if validateID(cid) == nil { + out = append(out, cid) + } + } + return out, nil +} + +// lock globally locks all locking operations for the container. +func (s *StateFile) lock() error { + s.once.Do(func() { + s.flock = flock.NewFlock(s.lockPath()) + }) + + if err := s.flock.Lock(); err != nil { + return fmt.Errorf("acquiring lock on %q: %v", s.flock, err) + } + return nil +} + +// lockForNew acquires the lock and checks if the state file doesn't exist. This +// is done to ensure that more than one creation didn't race to create +// containers with the same ID. +func (s *StateFile) lockForNew() error { + if err := s.lock(); err != nil { + return err + } + + // Checks if the container already exists by looking for the metadata file. + if _, err := os.Stat(s.statePath()); err == nil { + s.unlock() + return fmt.Errorf("container already exists") + } else if !os.IsNotExist(err) { + s.unlock() + return fmt.Errorf("looking for existing container: %v", err) + } + return nil +} + +// unlock globally unlocks all locking operations for the container. +func (s *StateFile) unlock() error { + if !s.flock.Locked() { + panic("unlock called without lock held") + } + + if err := s.flock.Unlock(); err != nil { + log.Warningf("Error to release lock on %q: %v", s.flock, err) + return fmt.Errorf("releasing lock on %q: %v", s.flock, err) + } + return nil +} + +// saveLocked saves 'v' to the state file. +// +// Preconditions: lock() must been called before. +func (s *StateFile) saveLocked(v interface{}) error { + if !s.flock.Locked() { + panic("saveLocked called without lock held") + } + + meta, err := json.Marshal(v) + if err != nil { + return err + } + if err := ioutil.WriteFile(s.statePath(), meta, 0640); err != nil { + return fmt.Errorf("writing json file: %v", err) + } + return nil +} + +func (s *StateFile) load(v interface{}) error { + if err := s.lock(); err != nil { + return err + } + defer s.unlock() + + metaBytes, err := ioutil.ReadFile(s.statePath()) + if err != nil { + return err + } + return json.Unmarshal(metaBytes, &v) +} + +func (s *StateFile) close() error { + if s.flock == nil { + return nil + } + if s.flock.Locked() { + panic("Closing locked file") + } + return s.flock.Close() +} + +func buildStatePath(rootDir, id string) string { + return filepath.Join(rootDir, id+stateFileExtension) +} + +// statePath is the full path to the state file. +func (s *StateFile) statePath() string { + return buildStatePath(s.RootDir, s.ID) +} + +// lockPath is the full path to the lock file. +func (s *StateFile) lockPath() string { + return filepath.Join(s.RootDir, s.ID+".lock") +} + +// destroy deletes all state created by the stateFile. It may be called with the +// lock file held. In that case, the lock file must still be unlocked and +// properly closed after destroy returns. +func (s *StateFile) destroy() error { + if err := os.Remove(s.statePath()); err != nil && !os.IsNotExist(err) { + return err + } + if err := os.Remove(s.lockPath()); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index 3fceecb3d..18b853e2e 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -601,7 +601,7 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) Mode: p9.FileMode(stat.Mode), UID: p9.UID(stat.Uid), GID: p9.GID(stat.Gid), - NLink: stat.Nlink, + NLink: uint64(stat.Nlink), RDev: stat.Rdev, Size: uint64(stat.Size), BlockSize: uint64(stat.Blksize), diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index 26467bdc7..9632776d2 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -151,13 +151,6 @@ func TestConfig() *boot.Config { } } -// TestConfigWithRoot returns the default configuration to use in tests. -func TestConfigWithRoot(rootDir string) *boot.Config { - conf := TestConfig() - conf.RootDir = rootDir - return conf -} - // NewSpecWithArgs creates a simple spec with the given args suitable for use // in tests. func NewSpecWithArgs(args ...string) *specs.Spec { diff --git a/scripts/runtime_tests.sh b/scripts/runtime_tests.sh new file mode 100755 index 000000000..fb82b2491 --- /dev/null +++ b/scripts/runtime_tests.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +source $(dirname $0)/common.sh + +if [ ! -v RUNTIME ]; then + echo 'Must set $RUNTIME' >&2 + exit 1 +fi + +install_runsc_for_test runtimes +test_runsc "//test/runtimes:${RUNTIME}_test" diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 7cc0de129..28064e557 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -175,6 +175,9 @@ func TestCheckpointRestore(t *testing.T) { t.Fatal(err) } + // TODO(b/143498576): Remove after github.com/moby/moby/issues/38963 is fixed. + time.Sleep(1 * time.Second) + if err := d.Restore("test"); err != nil { t.Fatal("docker restore failed:", err) } diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go index 6cd378a1b..126f0975a 100644 --- a/test/root/oom_score_adj_test.go +++ b/test/root/oom_score_adj_test.go @@ -40,6 +40,15 @@ var ( // TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a // single container sandbox. func TestOOMScoreAdjSingle(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + ppid, err := specutils.GetParentPid(os.Getpid()) if err != nil { t.Fatalf("getting parent pid: %v", err) @@ -84,7 +93,6 @@ func TestOOMScoreAdjSingle(t *testing.T) { s := testutil.NewSpecWithArgs("sleep", "1000") s.Process.OOMScoreAdj = testCase.OOMScoreAdj - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id}) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -123,6 +131,15 @@ func TestOOMScoreAdjSingle(t *testing.T) { // TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a // multi-container sandbox. func TestOOMScoreAdjMulti(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + ppid, err := specutils.GetParentPid(os.Getpid()) if err != nil { t.Fatalf("getting parent pid: %v", err) @@ -240,7 +257,6 @@ func TestOOMScoreAdjMulti(t *testing.T) { } } - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -327,13 +343,8 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { } func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) { - // Setup root dir if one hasn't been provided. if len(conf.RootDir) == 0 { - rootDir, err := testutil.SetupRootDir() - if err != nil { - return nil, nil, fmt.Errorf("error creating root dir: %v", err) - } - conf.RootDir = rootDir + panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") } var containers []*container.Container @@ -345,7 +356,6 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*c for _, b := range bundles { os.RemoveAll(b) } - os.RemoveAll(conf.RootDir) } for i, spec := range specs { bundleDir, err := testutil.SetupBundleDir(spec) diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 2e125525b..367295206 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -16,32 +16,32 @@ go_binary( ) runtime_test( + name = "go1.12", blacklist_file = "blacklist_go1.12.csv", - image = "gcr.io/gvisor-presubmit/go1.12", lang = "go", ) runtime_test( + name = "java11", blacklist_file = "blacklist_java11.csv", - image = "gcr.io/gvisor-presubmit/java11", lang = "java", ) runtime_test( + name = "nodejs12.4.0", blacklist_file = "blacklist_nodejs12.4.0.csv", - image = "gcr.io/gvisor-presubmit/nodejs12.4.0", lang = "nodejs", ) runtime_test( + name = "php7.3.6", blacklist_file = "blacklist_php7.3.6.csv", - image = "gcr.io/gvisor-presubmit/php7.3.6", lang = "php", ) runtime_test( + name = "python3.7.3", blacklist_file = "blacklist_python3.7.3.csv", - image = "gcr.io/gvisor-presubmit/python3.7.3", lang = "python", ) diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl index 7c11624b4..d458df1fd 100644 --- a/test/runtimes/build_defs.bzl +++ b/test/runtimes/build_defs.bzl @@ -2,32 +2,48 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") -# runtime_test is a macro that will create targets to run the given test target -# with different runtime options. def runtime_test( + name, lang, - image, + image_repo = "gcr.io/gvisor-presubmit", + image_name = None, + blacklist_file = None, shard_count = 50, - size = "enormous", - blacklist_file = ""): + size = "enormous"): + """Generates sh_test and blacklist test targets for a given runtime. + + Args: + name: The name of the runtime being tested. Typically, the lang + version. + This is used in the names of the generated test targets. + lang: The language being tested. + image_repo: The docker repository containing the proctor image to run. + i.e., the prefix to the fully qualified docker image id. + image_name: The name of the image in the image_repo. + Defaults to the test name. + blacklist_file: A test blacklist to pass to the runtime test's runner. + shard_count: See Bazel common test attributes. + size: See Bazel common test attributes. + """ + if image_name == None: + image_name = name args = [ "--lang", lang, "--image", - image, + "/".join([image_repo, image_name]), ] data = [ ":runner", ] - if blacklist_file != "": + if blacklist_file: args += ["--blacklist_file", "test/runtimes/" + blacklist_file] data += [blacklist_file] # Add a test that the blacklist parses correctly. - blacklist_test(lang, blacklist_file) + blacklist_test(name, blacklist_file) sh_test( - name = lang + "_test", + name = name + "_test", srcs = ["runner.sh"], args = args, data = data, @@ -35,15 +51,14 @@ def runtime_test( shard_count = shard_count, tags = [ # Requires docker and runsc to be configured before the test runs. - "manual", "local", ], ) -def blacklist_test(lang, blacklist_file): +def blacklist_test(name, blacklist_file): """Test that a blacklist parses correctly.""" go_test( - name = lang + "_blacklist_test", + name = name + "_blacklist_test", embed = [":runner"], srcs = ["blacklist_test.go"], args = ["--blacklist_file", "test/runtimes/" + blacklist_file], diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index b869ca6f9..833fbaa09 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1568,11 +1568,14 @@ cc_binary( srcs = ["proc_net.cc"], linkstatic = 1, deps = [ + ":socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", "//test/util:test_main", "//test/util:test_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc index 328192a05..427c42ede 100644 --- a/test/syscalls/linux/accept_bind.cc +++ b/test/syscalls/linux/accept_bind.cc @@ -17,7 +17,6 @@ #include <algorithm> #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/accept_bind_stream.cc b/test/syscalls/linux/accept_bind_stream.cc index b6cdb3f4f..7bcd91e9e 100644 --- a/test/syscalls/linux/accept_bind_stream.cc +++ b/test/syscalls/linux/accept_bind_stream.cc @@ -17,7 +17,6 @@ #include <algorithm> #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/bind.cc b/test/syscalls/linux/bind.cc index de8cca53b..9547c4ab2 100644 --- a/test/syscalls/linux/bind.cc +++ b/test/syscalls/linux/bind.cc @@ -17,7 +17,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc index 498c45f16..de1611c21 100644 --- a/test/syscalls/linux/chroot.cc +++ b/test/syscalls/linux/chroot.cc @@ -24,7 +24,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" diff --git a/test/syscalls/linux/connect_external.cc b/test/syscalls/linux/connect_external.cc index 98032ac19..bfe1da82e 100644 --- a/test/syscalls/linux/connect_external.cc +++ b/test/syscalls/linux/connect_external.cc @@ -22,7 +22,6 @@ #include <tuple> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index 85734c290..581f03533 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -533,6 +533,47 @@ TEST(ExecTest, CloexecEventfd) { W_EXITCODE(0, 0), ""); } +constexpr int kLinuxMaxSymlinks = 40; + +TEST(ExecTest, SymlinkLimitExceeded) { + std::string path = WorkloadPath(kBasicWorkload); + + // Hold onto TempPath objects so they are not destructed prematurely. + std::vector<TempPath> symlinks; + for (int i = 0; i < kLinuxMaxSymlinks + 1; i++) { + symlinks.push_back( + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", path))); + path = symlinks[i].path(); + } + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec(path, {path}, {}, /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ELOOP); +} + +TEST(ExecTest, SymlinkLimitRefreshedForInterpreter) { + std::string tmp_dir = "/tmp"; + std::string interpreter_path = "/bin/echo"; + TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + tmp_dir, absl::StrCat("#!", interpreter_path), 0755)); + std::string script_path = script.path(); + + // Hold onto TempPath objects so they are not destructed prematurely. + std::vector<TempPath> interpreter_symlinks; + std::vector<TempPath> script_symlinks; + for (int i = 0; i < kLinuxMaxSymlinks; i++) { + interpreter_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(tmp_dir, interpreter_path))); + interpreter_path = interpreter_symlinks[i].path(); + script_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(tmp_dir, script_path))); + script_path = script_symlinks[i].path(); + } + + CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), ""); +} + TEST(ExecveatTest, BasicWithFDCWD) { std::string path = WorkloadPath(kBasicWorkload); CheckExecveat(AT_FDCWD, path, {path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0), @@ -542,14 +583,26 @@ TEST(ExecveatTest, BasicWithFDCWD) { TEST(ExecveatTest, Basic) { std::string absolute_path = WorkloadPath(kBasicWorkload); std::string parent_dir = std::string(Dirname(absolute_path)); - std::string relative_path = std::string(Basename(absolute_path)); + std::string base = std::string(Basename(absolute_path)); const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); - CheckExecveat(dirfd.get(), relative_path, {absolute_path}, {}, /*flags=*/0, + CheckExecveat(dirfd.get(), base, {absolute_path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n")); } +TEST(ExecveatTest, FDNotADirectory) { + std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(absolute_path, 0)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), base, {absolute_path}, {}, + /*flags=*/0, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOTDIR); +} + TEST(ExecveatTest, AbsolutePathWithFDCWD) { std::string path = WorkloadPath(kBasicWorkload); CheckExecveat(AT_FDCWD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, @@ -564,6 +617,152 @@ TEST(ExecveatTest, AbsolutePath) { absl::StrCat(path, "\n")); } +TEST(ExecveatTest, EmptyPathBasic) { + std::string path = WorkloadPath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); + + CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH, ArgEnvExitStatus(0, 0), + absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, EmptyPathWithDirFD) { + std::string path = WorkloadPath(kBasicWorkload); + std::string parent_dir = std::string(Dirname(path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), "", {path}, {}, + AT_EMPTY_PATH, + /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, EACCES); +} + +TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) { + std::string path = WorkloadPath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat( + fd.get(), "", {path}, {}, /*flags=*/0, /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + +TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) { + std::string path = WorkloadPath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); + + CheckExecveat(fd.get(), path, {path}, {}, AT_EMPTY_PATH, + ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, RelativePathWithEmptyPathFlag) { + std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string parent_dir = std::string(Dirname(absolute_path)); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + + CheckExecveat(dirfd.get(), base, {absolute_path}, {}, AT_EMPTY_PATH, + ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n")); +} + +TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) { + std::string parent_dir = "/tmp"; + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(parent_dir, WorkloadPath(kBasicWorkload))); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + std::string base = std::string(Basename(link.path())); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {}, + AT_SYMLINK_NOFOLLOW, + /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ELOOP); +} + +TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) { + std::string parent_dir = "/tmp"; + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(parent_dir, WorkloadPath(kBasicWorkload))); + std::string path = link.path(); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(AT_FDCWD, path, {path}, {}, + AT_SYMLINK_NOFOLLOW, + /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ELOOP); +} + +TEST(ExecveatTest, SymlinkNoFollowAndEmptyPath) { + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + std::string path = link.path(); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, 0)); + + CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH | AT_SYMLINK_NOFOLLOW, + ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, SymlinkNoFollowIgnoreSymlinkAncestor) { + TempPath parent_link = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", "/bin")); + std::string path_with_symlink = JoinPath(parent_link.path(), "echo"); + + CheckExecveat(AT_FDCWD, path_with_symlink, {path_with_symlink}, {}, + AT_SYMLINK_NOFOLLOW, ArgEnvExitStatus(0, 0), ""); +} + +TEST(ExecveatTest, SymlinkNoFollowWithNormalFile) { + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/bin", O_DIRECTORY)); + + CheckExecveat(dirfd.get(), "echo", {"echo"}, {}, AT_SYMLINK_NOFOLLOW, + ArgEnvExitStatus(0, 0), ""); +} + +TEST(ExecveatTest, BasicWithCloexecFD) { + std::string path = WorkloadPath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); + + CheckExecveat(fd.get(), "", {path}, {}, AT_SYMLINK_NOFOLLOW | AT_EMPTY_PATH, + ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, InterpreterScriptWithCloexecFD) { + std::string path = WorkloadPath(kExitScript); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), "", {path}, {}, + AT_EMPTY_PATH, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + +TEST(ExecveatTest, InterpreterScriptWithCloexecDirFD) { + std::string absolute_path = WorkloadPath(kExitScript); + std::string parent_dir = std::string(Dirname(absolute_path)); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_CLOEXEC | O_DIRECTORY)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {}, + /*flags=*/0, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + +TEST(ExecveatTest, InvalidFlags) { + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat( + /*dirfd=*/-1, "", {}, {}, /*flags=*/0xFFFF, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, EINVAL); +} + // Priority consistent across calls to execve() TEST(GetpriorityTest, ExecveMaintainsPriority) { int prio = 16; diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h index 36efabcae..4d155b618 100644 --- a/test/syscalls/linux/file_base.h +++ b/test/syscalls/linux/file_base.h @@ -32,7 +32,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc index 4948a76f0..c4f8bff08 100644 --- a/test/syscalls/linux/ioctl.cc +++ b/test/syscalls/linux/ioctl.cc @@ -25,7 +25,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc index 08ff4052c..7fd0ea20c 100644 --- a/test/syscalls/linux/madvise.cc +++ b/test/syscalls/linux/madvise.cc @@ -25,7 +25,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/file_descriptor.h" #include "test/util/logging.h" #include "test/util/memory_util.h" diff --git a/test/syscalls/linux/memory_accounting.cc b/test/syscalls/linux/memory_accounting.cc index a6e20f9c3..ff2f49863 100644 --- a/test/syscalls/linux/memory_accounting.cc +++ b/test/syscalls/linux/memory_accounting.cc @@ -16,7 +16,6 @@ #include <map> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index fcf64ee59..92ae55eec 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -130,6 +130,20 @@ void CookedPacketTest::SetUp() { GTEST_SKIP(); } + if (!IsRunningOnGvisor()) { + FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); + FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY)); + char enabled; + ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + } + ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), SyscallSucceeds()); } diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 10e2a6dfc..c0b354e65 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -20,7 +20,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/notification.h" #include "absl/time/clock.h" diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc index 5e3eb1735..2cecf2e5f 100644 --- a/test/syscalls/linux/pread64.cc +++ b/test/syscalls/linux/pread64.cc @@ -20,7 +20,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc index eebd129f2..f7ea44054 100644 --- a/test/syscalls/linux/preadv.cc +++ b/test/syscalls/linux/preadv.cc @@ -22,7 +22,6 @@ #include <string> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc index aac960130..c9246367d 100644 --- a/test/syscalls/linux/preadv2.cc +++ b/test/syscalls/linux/preadv2.cc @@ -21,7 +21,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/memory/memory.h" #include "test/syscalls/linux/file_base.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index efdaf202b..65bad06d4 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -12,8 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <arpa/inet.h> +#include <errno.h> +#include <netinet/in.h> +#include <poll.h> +#include <sys/socket.h> +#include <sys/syscall.h> +#include <sys/types.h> + #include "gtest/gtest.h" -#include "gtest/gtest.h" +#include "absl/strings/str_split.h" +#include "absl/time/clock.h" +#include "test/syscalls/linux/socket_test_util.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" @@ -57,6 +67,265 @@ TEST(ProcSysNetIpv4Sack, CanReadAndWrite) { EXPECT_EQ(buf, to_write); } +PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp, + const std::string &type, + const std::string &item) { + std::vector<std::string> snmp_vec = absl::StrSplit(snmp, '\n'); + + // /proc/net/snmp prints a line of headers followed by a line of metrics. + // Only search the headers. + for (unsigned i = 0; i < snmp_vec.size(); i = i + 2) { + if (!absl::StartsWith(snmp_vec[i], type)) continue; + + std::vector<std::string> fields = + absl::StrSplit(snmp_vec[i], ' ', absl::SkipWhitespace()); + + EXPECT_TRUE((i + 1) < snmp_vec.size()); + std::vector<std::string> values = + absl::StrSplit(snmp_vec[i + 1], ' ', absl::SkipWhitespace()); + + EXPECT_TRUE(!fields.empty() && fields.size() == values.size()); + + // Metrics start at the first index. + for (unsigned j = 1; j < fields.size(); j++) { + if (fields[j] == item) { + uint64_t val; + if (!absl::SimpleAtoi(values[j], &val)) { + return PosixError(EINVAL, + absl::StrCat("field is not a number: ", values[j])); + } + + return val; + } + } + } + // We should never get here. + return PosixError( + EINVAL, absl::StrCat("failed to find ", type, "/", item, " in:", snmp)); +} + +TEST(ProcNetSnmp, TcpReset_NoRandomSave) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + DisableSave ds; + + uint64_t oldAttemptFails; + uint64_t oldActiveOpens; + uint64_t oldOutRsts; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + oldOutRsts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts")); + oldAttemptFails = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails")); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); + + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = htons(1234), + }; + + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(connect(s.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallFailsWithErrno(ECONNREFUSED)); + + uint64_t newAttemptFails; + uint64_t newActiveOpens; + uint64_t newOutRsts; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + newOutRsts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts")); + newAttemptFails = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails")); + + EXPECT_EQ(oldActiveOpens, newActiveOpens - 1); + EXPECT_EQ(oldOutRsts, newOutRsts - 1); + EXPECT_EQ(oldAttemptFails, newAttemptFails - 1); +} + +TEST(ProcNetSnmp, TcpEstab_NoRandomSave) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + DisableSave ds; + + uint64_t oldEstabResets; + uint64_t oldActiveOpens; + uint64_t oldPassiveOpens; + uint64_t oldCurrEstab; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + oldPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens")); + oldCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); + oldEstabResets = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets")); + + FileDescriptor s_listen = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = 0, + }; + + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(bind(s_listen.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceeds()); + ASSERT_THAT(listen(s_listen.get(), 1), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = sizeof(sin); + ASSERT_THAT( + getsockname(s_listen.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen), + SyscallSucceeds()); + + FileDescriptor s_connect = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_THAT(connect(s_connect.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceeds()); + + auto s_accept = + ASSERT_NO_ERRNO_AND_VALUE(Accept(s_listen.get(), nullptr, nullptr)); + + uint64_t newEstabResets; + uint64_t newActiveOpens; + uint64_t newPassiveOpens; + uint64_t newCurrEstab; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + newPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens")); + newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); + + EXPECT_EQ(oldActiveOpens, newActiveOpens - 1); + EXPECT_EQ(oldPassiveOpens, newPassiveOpens - 1); + EXPECT_EQ(oldCurrEstab, newCurrEstab - 2); + + // Send 1 byte from client to server. + ASSERT_THAT(send(s_connect.get(), "a", 1, 0), SyscallSucceedsWithValue(1)); + + constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. + + // Wait until server-side fd sees the data on its side but don't read it. + struct pollfd poll_fd = {s_accept.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + // Now close server-side fd without reading the data which leads to a RST + // packet sent to client side. + s_accept.reset(-1); + + // Wait until client-side fd sees RST packet. + struct pollfd poll_fd1 = {s_connect.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd1, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + // Now close client-side fd. + s_connect.reset(-1); + + // Wait until the process of the netstack. + absl::SleepFor(absl::Seconds(1)); + + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); + newEstabResets = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets")); + + EXPECT_EQ(oldCurrEstab, newCurrEstab); + EXPECT_EQ(oldEstabResets, newEstabResets - 2); +} + +TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + DisableSave ds; + + uint64_t oldOutDatagrams; + uint64_t oldNoPorts; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + oldNoPorts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts")); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = htons(4444), + }; + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(sendto(s.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceedsWithValue(1)); + + uint64_t newOutDatagrams; + uint64_t newNoPorts; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + newNoPorts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts")); + + EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1); + EXPECT_EQ(oldNoPorts, newNoPorts - 1); +} + +TEST(ProcNetSnmp, UdpIn) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + const DisableSave ds; + + uint64_t oldOutDatagrams; + uint64_t oldInDatagrams; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + oldInDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams")); + + std::cerr << "snmp: " << std::endl << snmp << std::endl; + FileDescriptor server = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = htons(0), + }; + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(bind(server.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceeds()); + // Get the port bound by the server socket. + socklen_t addrlen = sizeof(sin); + ASSERT_THAT( + getsockname(server.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen), + SyscallSucceeds()); + + FileDescriptor client = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + ASSERT_THAT( + sendto(client.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceedsWithValue(1)); + + char buf[128]; + ASSERT_THAT(recvfrom(server.get(), buf, sizeof(buf), 0, NULL, NULL), + SyscallSucceedsWithValue(1)); + + uint64_t newOutDatagrams; + uint64_t newInDatagrams; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + std::cerr << "new snmp: " << std::endl << snmp << std::endl; + newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + newInDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams")); + + EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1); + EXPECT_EQ(oldInDatagrams, newInDatagrams - 1); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc index f61795592..2659f6a98 100644 --- a/test/syscalls/linux/proc_net_tcp.cc +++ b/test/syscalls/linux/proc_net_tcp.cc @@ -18,7 +18,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" diff --git a/test/syscalls/linux/proc_net_udp.cc b/test/syscalls/linux/proc_net_udp.cc index 369df8e0e..f06f1a24b 100644 --- a/test/syscalls/linux/proc_net_udp.cc +++ b/test/syscalls/linux/proc_net_udp.cc @@ -18,7 +18,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 83dbd1364..66db0acaa 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc index e1603fc2d..b48fe540d 100644 --- a/test/syscalls/linux/pwrite64.cc +++ b/test/syscalls/linux/pwrite64.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc index f6a0fc96c..1dbc0d6df 100644 --- a/test/syscalls/linux/pwritev2.cc +++ b/test/syscalls/linux/pwritev2.cc @@ -21,7 +21,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc index f327ec3a9..4069cbc7e 100644 --- a/test/syscalls/linux/readv.cc +++ b/test/syscalls/linux/readv.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/syscalls/linux/readv_common.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/readv_common.cc b/test/syscalls/linux/readv_common.cc index 35d2dd9e3..9658f7d42 100644 --- a/test/syscalls/linux/readv_common.cc +++ b/test/syscalls/linux/readv_common.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/readv_socket.cc b/test/syscalls/linux/readv_socket.cc index 3c315cc02..9b6972201 100644 --- a/test/syscalls/linux/readv_socket.cc +++ b/test/syscalls/linux/readv_socket.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/syscalls/linux/readv_common.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc index c9d76c2e2..5b474ff32 100644 --- a/test/syscalls/linux/rename.cc +++ b/test/syscalls/linux/rename.cc @@ -17,7 +17,6 @@ #include <string> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/util/capability_util.h" #include "test/util/cleanup.h" diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc index 88c010aec..e06a2666d 100644 --- a/test/syscalls/linux/select.cc +++ b/test/syscalls/linux/select.cc @@ -21,7 +21,6 @@ #include <cstdio> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/time.h" #include "test/syscalls/linux/base_poll_test.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc index 1c56540bc..3331288b7 100644 --- a/test/syscalls/linux/sendfile_socket.cc +++ b/test/syscalls/linux/sendfile_socket.cc @@ -185,7 +185,7 @@ TEST_P(SendFileTest, Shutdown) { // Create a socket. std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); const FileDescriptor client(std::get<0>(fds)); - FileDescriptor server(std::get<1>(fds)); // non-const, released below. + FileDescriptor server(std::get<1>(fds)); // non-const, reset below. // If this is a TCP socket, then turn off linger. if (GetParam() == AF_INET) { @@ -210,14 +210,14 @@ TEST_P(SendFileTest, Shutdown) { // checking the contents (other tests do that), so we just re-use the same // buffer as above. ScopedThread t([&]() { - int done = 0; + size_t done = 0; while (done < data.size()) { - int n = read(server.get(), data.data(), data.size()); + int n = RetryEINTR(read)(server.get(), data.data(), data.size()); ASSERT_THAT(n, SyscallSucceeds()); done += n; } // Close the server side socket. - ASSERT_THAT(close(server.release()), SyscallSucceeds()); + server.reset(); }); // Continuously stream from the file to the socket. Note we do not assert diff --git a/test/syscalls/linux/sigaltstack.cc b/test/syscalls/linux/sigaltstack.cc index 69b6e4f90..6fd3989a4 100644 --- a/test/syscalls/linux/sigaltstack.cc +++ b/test/syscalls/linux/sigaltstack.cc @@ -22,7 +22,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/cleanup.h" #include "test/util/fs_util.h" #include "test/util/multiprocess_util.h" diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc index 9379d5878..09ecad34a 100644 --- a/test/syscalls/linux/signalfd.cc +++ b/test/syscalls/linux/signalfd.cc @@ -24,7 +24,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/synchronization/mutex.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc index d20821cac..6b27f6eab 100644 --- a/test/syscalls/linux/socket_bind_to_device.cc +++ b/test/syscalls/linux/socket_bind_to_device.cc @@ -32,7 +32,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_bind_to_device_util.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc index 4d2400328..5767181a1 100644 --- a/test/syscalls/linux/socket_bind_to_device_distribution.cc +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -33,7 +33,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_bind_to_device_util.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc index a7365d139..e4641c62e 100644 --- a/test/syscalls/linux/socket_bind_to_device_sequence.cc +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -33,7 +33,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_bind_to_device_util.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_blocking.cc b/test/syscalls/linux/socket_blocking.cc index 00c50d1bf..d7ce57566 100644 --- a/test/syscalls/linux/socket_blocking.cc +++ b/test/syscalls/linux/socket_blocking.cc @@ -20,7 +20,6 @@ #include <cstdio> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index 51d614639..e8f24a59e 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 322ee07ad..ab375aaaf 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -16,6 +16,7 @@ #include <netinet/in.h> #include <poll.h> #include <string.h> +#include <sys/epoll.h> #include <sys/socket.h> #include <atomic> @@ -516,6 +517,112 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + constexpr int kThreadCount = 3; + + // TODO(b/141211329): endpointsByNic.seed has to be saved/restored. + const DisableSave ds141211329; + + // Create listening sockets. + FileDescriptor listener_fds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + listener_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)); + int fd = listener_fds[i].get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (i != 0) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10; + FileDescriptor client_fds[kConnectAttempts]; + + // Do the first run without save/restore. + DisableSave ds; + for (int i = 0; i < kConnectAttempts; i++) { + client_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + } + ds.reset(); + + // Check that a mapping of client and server sockets has + // not been change after save/restore. + for (int i = 0; i < kConnectAttempts; i++) { + EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + } + + int epollfd; + ASSERT_THAT(epollfd = epoll_create1(0), SyscallSucceeds()); + + for (int i = 0; i < kThreadCount; i++) { + int fd = listener_fds[i].get(); + struct epoll_event ev; + ev.data.fd = fd; + ev.events = EPOLLIN; + ASSERT_THAT(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev), SyscallSucceeds()); + } + + std::map<uint16_t, int> portToFD; + + for (int i = 0; i < kConnectAttempts * 2; i++) { + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + struct epoll_event ev; + int data, fd; + + ASSERT_THAT(epoll_wait(epollfd, &ev, 1, -1), SyscallSucceedsWithValue(1)); + + fd = ev.data.fd; + EXPECT_THAT(RetryEINTR(recvfrom)(fd, &data, sizeof(data), 0, + reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceedsWithValue(sizeof(data))); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr)); + auto prev_port = portToFD.find(port); + // Check that all packets from one client have been delivered to the same + // server socket. + if (prev_port == portToFD.end()) { + portToFD[port] = ev.data.fd; + } else { + EXPECT_EQ(portToFD[port], ev.data.fd); + } + } +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetReusePortTest, ::testing::Values( diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index bfa7943b1..592448289 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -24,14 +24,13 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { -TEST_P(TCPSocketPairTest, TcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, TcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -40,7 +39,7 @@ TEST_P(TCPSocketPairTest, TcpInfoSucceedes) { SyscallSucceeds()); } -TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, ShortTcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -49,7 +48,7 @@ TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) { SyscallSucceeds()); } -TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; diff --git a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc index de63f79d9..f178f1af9 100644 --- a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc @@ -22,7 +22,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 044394ba7..2a4ed04a5 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -24,7 +24,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index fa9a9df6f..b02872308 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -23,7 +23,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc index 3a068aacf..3c3712b50 100644 --- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc @@ -23,7 +23,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index 67d29af0a..b828b6844 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -21,7 +21,6 @@ #include <cstdio> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index 8b8993d3d..98ae414f3 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -27,7 +27,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_non_blocking.cc b/test/syscalls/linux/socket_non_blocking.cc index 73e6dc618..c3520cadd 100644 --- a/test/syscalls/linux/socket_non_blocking.cc +++ b/test/syscalls/linux/socket_non_blocking.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc index 3c599b6e8..d91c5ed39 100644 --- a/test/syscalls/linux/socket_non_stream.cc +++ b/test/syscalls/linux/socket_non_stream.cc @@ -19,7 +19,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc index 76127d181..62d87c1af 100644 --- a/test/syscalls/linux/socket_non_stream_blocking.cc +++ b/test/syscalls/linux/socket_non_stream_blocking.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc index 0417dd347..346443f96 100644 --- a/test/syscalls/linux/socket_stream.cc +++ b/test/syscalls/linux/socket_stream.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc index 8367460d2..e9cc082bf 100644 --- a/test/syscalls/linux/socket_stream_blocking.cc +++ b/test/syscalls/linux/socket_stream_blocking.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_stream_nonblock.cc b/test/syscalls/linux/socket_stream_nonblock.cc index b00748b97..74d608741 100644 --- a/test/syscalls/linux/socket_stream_nonblock.cc +++ b/test/syscalls/linux/socket_stream_nonblock.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 70710195c..be38907c2 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -30,7 +30,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_format.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc index 875f0391f..8a28202a8 100644 --- a/test/syscalls/linux/socket_unix.cc +++ b/test/syscalls/linux/socket_unix.cc @@ -25,7 +25,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc index 1092e29b1..1159c5229 100644 --- a/test/syscalls/linux/socket_unix_cmsg.cc +++ b/test/syscalls/linux/socket_unix_cmsg.cc @@ -25,7 +25,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_dgram.cc b/test/syscalls/linux/socket_unix_dgram.cc index 3e0f611d2..3245cf7c9 100644 --- a/test/syscalls/linux/socket_unix_dgram.cc +++ b/test/syscalls/linux/socket_unix_dgram.cc @@ -17,7 +17,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc index 707052af8..cd4fba25c 100644 --- a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc +++ b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc @@ -15,7 +15,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc index b5c82cd67..276a94eb8 100644 --- a/test/syscalls/linux/socket_unix_non_stream.cc +++ b/test/syscalls/linux/socket_unix_non_stream.cc @@ -19,7 +19,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/memory_util.h" diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc index 6f6367dd5..60fa9e38a 100644 --- a/test/syscalls/linux/socket_unix_seqpacket.cc +++ b/test/syscalls/linux/socket_unix_seqpacket.cc @@ -17,7 +17,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index 8f38ed92f..563467365 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -17,7 +17,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_abstract.cc b/test/syscalls/linux/socket_unix_unbound_abstract.cc index 4b5832de8..7f5816ace 100644 --- a/test/syscalls/linux/socket_unix_unbound_abstract.cc +++ b/test/syscalls/linux/socket_unix_unbound_abstract.cc @@ -15,7 +15,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_dgram.cc b/test/syscalls/linux/socket_unix_unbound_dgram.cc index 52aef891f..907dca0f1 100644 --- a/test/syscalls/linux/socket_unix_unbound_dgram.cc +++ b/test/syscalls/linux/socket_unix_unbound_dgram.cc @@ -17,7 +17,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_filesystem.cc b/test/syscalls/linux/socket_unix_unbound_filesystem.cc index 8cb03c450..b14f24086 100644 --- a/test/syscalls/linux/socket_unix_unbound_filesystem.cc +++ b/test/syscalls/linux/socket_unix_unbound_filesystem.cc @@ -15,7 +15,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc index 0575f2e1d..50ffa1d04 100644 --- a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc +++ b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc @@ -15,7 +15,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_stream.cc b/test/syscalls/linux/socket_unix_unbound_stream.cc index e483d2777..344918c34 100644 --- a/test/syscalls/linux/socket_unix_unbound_stream.cc +++ b/test/syscalls/linux/socket_unix_unbound_stream.cc @@ -15,7 +15,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 88ab90b5b..30de2f8ff 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -24,7 +24,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index bfa031bce..277d6835a 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -394,8 +394,15 @@ TEST_P(TcpSocketTest, PollWithFullBufferBlocks) { sizeof(tcp_nodelay_flag)), SyscallSucceeds()); + // Set a 256KB send/receive buffer. + int buf_sz = 1 << 18; + EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); + // Create a large buffer that will be used for sending. - std::vector<char> buf(10 * sendbuf_size_); + std::vector<char> buf(1 << 16); // Write until we receive an error. while (RetryEINTR(send)(s_, buf.data(), buf.size(), 0) != -1) { diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h index c413d63ea..61526b4e7 100644 --- a/test/util/multiprocess_util.h +++ b/test/util/multiprocess_util.h @@ -109,6 +109,15 @@ PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, const std::string& pathname const std::function<void()>& fn, pid_t* child, int* execve_errno); +inline PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, + const std::string& pathname, + const ExecveArray& argv, + const ExecveArray& envv, int flags, + pid_t* child, int* execve_errno) { + return ForkAndExecveat( + dirfd, pathname, argv, envv, flags, [] {}, child, execve_errno); +} + // Calls fn in a forked subprocess and returns the exit status of the // subprocess. // diff --git a/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go b/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go index 8baec5458..3b9346843 100644 --- a/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go +++ b/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.13 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/third_party/gvsync/downgradable_rwmutex_unsafe.go b/third_party/gvsync/downgradable_rwmutex_unsafe.go index 1f6007aa1..b7862d185 100644 --- a/third_party/gvsync/downgradable_rwmutex_unsafe.go +++ b/third_party/gvsync/downgradable_rwmutex_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/third_party/gvsync/memmove_unsafe.go b/third_party/gvsync/memmove_unsafe.go index 84b69f215..9dd1d6142 100644 --- a/third_party/gvsync/memmove_unsafe.go +++ b/third_party/gvsync/memmove_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. |