diff options
40 files changed, 1135 insertions, 241 deletions
@@ -133,11 +133,9 @@ The [gvisor-users mailing list][gvisor-users-list] and [gvisor-dev mailing list][gvisor-dev-list] are good starting points for questions and discussion. -## Security +## Security Policy -Sensitive security-related questions, comments and disclosures can be sent to -the [gvisor-security mailing list][gvisor-security-list]. The full security -disclosure policy is defined in the [community][community] repository. +See [SECURITY.md](SECURITY.md). ## Contributing @@ -147,7 +145,6 @@ See [Contributing.md](CONTRIBUTING.md). [community]: https://gvisor.googlesource.com/community [docker]: https://www.docker.com [git]: https://git-scm.com -[gvisor-security-list]: https://groups.google.com/forum/#!forum/gvisor-security [gvisor-users-list]: https://groups.google.com/forum/#!forum/gvisor-users [gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev [oci]: https://www.opencontainers.org diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..154d68cb3 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,11 @@ +# Security and Vulnerability Reporting + +Sensitive security-related questions, comments, and reports should be sent to +the [gvisor-security mailing list][gvisor-security-list]. You should receive a +prompt response, typically within 48 hours. + +Policies for security list access, vulnerability embargo, and vulnerability +disclosure are outlined in the [community][community] repository. + +[community]: https://gvisor.googlesource.com/community +[gvisor-security-list]: https://groups.google.com/forum/#!forum/gvisor-security diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 7d742871a..257f67222 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -271,7 +271,7 @@ type Statx struct { } // FileMode represents a mode_t. -type FileMode uint +type FileMode uint16 // Permissions returns just the permission bits. func (m FileMode) Permissions() FileMode { diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 51967b811..1b5dac99a 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -11,8 +11,9 @@ go_library( "bits.go", "bits32.go", "bits64.go", - "uint64_arch_amd64.go", + "uint64_arch.go", "uint64_arch_amd64_asm.s", + "uint64_arch_arm64_asm.s", "uint64_arch_generic.go", ], importpath = "gvisor.dev/gvisor/pkg/bits", diff --git a/pkg/bits/uint64_arch_amd64.go b/pkg/bits/uint64_arch.go index faccaa61a..9f23eff77 100644 --- a/pkg/bits/uint64_arch_amd64.go +++ b/pkg/bits/uint64_arch.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 package bits diff --git a/pkg/bits/uint64_arch_arm64_asm.s b/pkg/bits/uint64_arch_arm64_asm.s new file mode 100644 index 000000000..814ba562d --- /dev/null +++ b/pkg/bits/uint64_arch_arm64_asm.s @@ -0,0 +1,33 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +TEXT ·TrailingZeros64(SB),$0-16 + MOVD x+0(FP), R0 + RBIT R0, R0 + CLZ R0, R0 // return 64 if x == 0 + MOVD R0, ret+8(FP) + RET + +TEXT ·MostSignificantOne64(SB),$0-16 + MOVD x+0(FP), R0 + CLZ R0, R0 // return 64 if x == 0 + MOVD $63, R1 + SUBS R0, R1, R0 // ret = 63 - CLZ + BPL end + MOVD $64, R0 // x == 0 +end: + MOVD R0, ret+8(FP) + RET diff --git a/pkg/bits/uint64_arch_generic.go b/pkg/bits/uint64_arch_generic.go index 7dd2d1480..9dd2098d1 100644 --- a/pkg/bits/uint64_arch_generic.go +++ b/pkg/bits/uint64_arch_generic.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !amd64 +// +build !amd64,!arm64 package bits diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 2412aa5e1..221516c6c 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -505,12 +505,27 @@ func (c *Client) sendRecvChannel(t message, r message) error { ch.active = false c.channelsMu.Unlock() c.channelsWg.Done() - return err + // Map all transport errors to EIO, but ensure that the real error + // is logged. + log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err) + return syscall.EIO } } - // Send the message. - err := ch.sendRecv(c, t, r) + // Send the request and receive the server's response. + rsz, err := ch.send(t) + if err != nil { + // See above. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err) + return syscall.EIO + } + + // Parse the server's response. + _, retErr := ch.recv(r, rsz) // Release the channel. c.channelsMu.Lock() @@ -523,7 +538,7 @@ func (c *Client) sendRecvChannel(t message, r message) error { c.channelsMu.Unlock() c.channelsWg.Done() - return err + return retErr } // Version returns the negotiated 9P2000.L.Google version number. diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 69c886a5d..e717e6161 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -452,7 +452,9 @@ func (cs *connState) initializeChannels() (err error) { cs.channelWg.Add(1) go func() { // S/R-SAFE: Server side. defer cs.channelWg.Done() - res.service(cs) + if err := res.service(cs); err != nil { + log.Warningf("p9.channel.service: %v", err) + } }() } diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go index 7cdf4ecc3..233f825e3 100644 --- a/pkg/p9/transport_flipcall.go +++ b/pkg/p9/transport_flipcall.go @@ -132,7 +132,7 @@ func (ch *channel) send(m message) (uint32, error) { if filer, ok := m.(filer); ok { if f := filer.FilePayload(); f != nil { if err := ch.fds.SendFD(f.FD()); err != nil { - return 0, syscall.EIO // Map everything to EIO. + return 0, err } f.Close() // Per sendRecvLegacy. sentFD = true // To mark below. @@ -162,15 +162,7 @@ func (ch *channel) send(m message) (uint32, error) { } // Perform the one-shot communication. - n, err := ch.data.SendRecv(ssz) - if err != nil { - if n > 0 { - return n, nil - } - return 0, syscall.EIO // See above. - } - - return n, nil + return ch.data.SendRecv(ssz) } // recv decodes a message that exists on the channel. @@ -249,15 +241,3 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) { return r, nil } - -// sendRecv sends the given message over the channel. -// -// This is used by the client. -func (ch *channel) sendRecv(c *Client, m, r message) error { - rsz, err := ch.send(m) - if err != nil { - return err - } - _, err = ch.recv(r, rsz) - return err -} diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 246b97161..5a388dad1 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "fmt" "strings" "gvisor.dev/gvisor/pkg/abi/linux" @@ -207,6 +208,11 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name } func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name string, flags FileFlags, perm FilePermissions) (*File, error) { + // Sanity check. + if parent.Inode.overlay == nil { + panic(fmt.Sprintf("overlayCreate called with non-overlay parent inode (parent InodeOperations type is %T)", parent.Inode.InodeOperations)) + } + // Dirent.Create takes renameMu if the Inode is an overlay Inode. if err := copyUpLockedForRename(ctx, parent); err != nil { return nil, err diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 5e28982c5..f70239449 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -66,7 +66,7 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo if s.SupportsIPv6() { contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc) contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte("")) - contents["tcp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")) + contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc) contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")) } } @@ -310,44 +310,51 @@ func networkToHost16(n uint16) uint16 { return usermem.ByteOrder.Uint16(buf[:]) } -func writeInetAddr(w io.Writer, a linux.SockAddrInet) { - // linux.SockAddrInet.Port is stored in the network byte order and is - // printed like a number in host byte order. Note that all numbers in host - // byte order are printed with the most-significant byte first when - // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux. - port := networkToHost16(a.Port) - - // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order - // (i.e. most-significant byte in index 0). Linux represents this as a - // __be32 which is a typedef for an unsigned int, and is printed with - // %X. This means that for a little-endian machine, Linux prints the - // least-significant byte of the address first. To emulate this, we first - // invert the byte order for the address using usermem.ByteOrder.Uint32, - // which makes it have the equivalent encoding to a __be32 on a little - // endian machine. Note that this operation is a no-op on a big endian - // machine. Then similar to Linux, we format it with %X, which will print - // the most-significant byte of the __be32 address first, which is now - // actually the least-significant byte of the original address in - // linux.SockAddrInet.Addr on little endian machines, due to the conversion. - addr := usermem.ByteOrder.Uint32(a.Addr[:]) - - fmt.Fprintf(w, "%08X:%04X ", addr, port) -} +func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { + switch family { + case linux.AF_INET: + var a linux.SockAddrInet + if i != nil { + a = *i.(*linux.SockAddrInet) + } -// netTCP implements seqfile.SeqSource for /proc/net/tcp. -// -// +stateify savable -type netTCP struct { - k *kernel.Kernel -} + // linux.SockAddrInet.Port is stored in the network byte order and is + // printed like a number in host byte order. Note that all numbers in host + // byte order are printed with the most-significant byte first when + // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux. + port := networkToHost16(a.Port) + + // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order + // (i.e. most-significant byte in index 0). Linux represents this as a + // __be32 which is a typedef for an unsigned int, and is printed with + // %X. This means that for a little-endian machine, Linux prints the + // least-significant byte of the address first. To emulate this, we first + // invert the byte order for the address using usermem.ByteOrder.Uint32, + // which makes it have the equivalent encoding to a __be32 on a little + // endian machine. Note that this operation is a no-op on a big endian + // machine. Then similar to Linux, we format it with %X, which will print + // the most-significant byte of the __be32 address first, which is now + // actually the least-significant byte of the original address in + // linux.SockAddrInet.Addr on little endian machines, due to the conversion. + addr := usermem.ByteOrder.Uint32(a.Addr[:]) + + fmt.Fprintf(w, "%08X:%04X ", addr, port) + case linux.AF_INET6: + var a linux.SockAddrInet6 + if i != nil { + a = *i.(*linux.SockAddrInet6) + } -// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. -func (*netTCP) NeedsUpdate(generation int64) bool { - return true + port := networkToHost16(a.Port) + addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4]) + addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8]) + addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12]) + addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16]) + fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port) + } } -// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. -func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { +func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kernel.Kernel, h seqfile.SeqHandle, fa int, header []byte) ([]seqfile.SeqData, int64) { // t may be nil here if our caller is not part of a task goroutine. This can // happen for example if we're here for "sentryctl cat". When t is nil, // degrade gracefully and retrieve what we can. @@ -358,7 +365,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se } var buf bytes.Buffer - for _, se := range n.k.ListSockets() { + for _, se := range k.ListSockets() { s := se.Sock.Get() if s == nil { log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) @@ -369,7 +376,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se if !ok { panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) } - if family, stype, _ := sops.Type(); !(family == linux.AF_INET && stype == linux.SOCK_STREAM) { + if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { s.DecRef() // Not tcp4 sockets. continue @@ -384,22 +391,22 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se fmt.Fprintf(&buf, "%4d: ", se.ID) // Field: local_adddress. - var localAddr linux.SockAddrInet + var localAddr linux.SockAddr if t != nil { if local, _, err := sops.GetSockName(t); err == nil { - localAddr = *local.(*linux.SockAddrInet) + localAddr = local } } - writeInetAddr(&buf, localAddr) + writeInetAddr(&buf, fa, localAddr) // Field: rem_address. - var remoteAddr linux.SockAddrInet + var remoteAddr linux.SockAddr if t != nil { if remote, _, err := sops.GetPeerName(t); err == nil { - remoteAddr = *remote.(*linux.SockAddrInet) + remoteAddr = remote } } - writeInetAddr(&buf, remoteAddr) + writeInetAddr(&buf, fa, remoteAddr) // Field: state; socket state. fmt.Fprintf(&buf, "%02X ", sops.State()) @@ -465,7 +472,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se data := []seqfile.SeqData{ { - Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n"), + Buf: header, Handle: n, }, { @@ -476,6 +483,42 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } +// netTCP implements seqfile.SeqSource for /proc/net/tcp. +// +// +stateify savable +type netTCP struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netTCP) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + header := []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n") + return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET, header) +} + +// netTCP6 implements seqfile.SeqSource for /proc/net/tcp6. +// +// +stateify savable +type netTCP6 struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netTCP6) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netTCP6) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + header := []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n") + return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET6, header) +} + // netUDP implements seqfile.SeqSource for /proc/net/udp. // // +stateify savable @@ -529,7 +572,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se localAddr = *local.(*linux.SockAddrInet) } } - writeInetAddr(&buf, localAddr) + writeInetAddr(&buf, linux.AF_INET, &localAddr) // Field: rem_address. var remoteAddr linux.SockAddrInet @@ -538,7 +581,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se remoteAddr = *remote.(*linux.SockAddrInet) } } - writeInetAddr(&buf, remoteAddr) + writeInetAddr(&buf, linux.AF_INET, &remoteAddr) // Field: state; socket state. fmt.Fprintf(&buf, "%02X ", sops.State()) diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index 0ef13f2f5..56e92721e 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -230,7 +230,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent // But for whatever crazy reason, you can still walk to the given node. for _, tg := range rpf.iops.pidns.ThreadGroups() { if leader := tg.Leader(); leader != nil { - name := strconv.FormatUint(uint64(tg.ID()), 10) + name := strconv.FormatUint(uint64(rpf.iops.pidns.IDOfThreadGroup(tg)), 10) m[name] = fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice) names = append(names, name) } diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go index c620227c9..0bd82e480 100644 --- a/pkg/sentry/fsimpl/memfs/directory.go +++ b/pkg/sentry/fsimpl/memfs/directory.go @@ -32,7 +32,7 @@ type directory struct { childList dentryList } -func (fs *filesystem) newDirectory(creds *auth.Credentials, mode uint16) *inode { +func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *inode { dir := &directory{} dir.inode.init(dir, fs, creds, mode) dir.inode.nlink = 2 // from "." and parent directory or ".." for root diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go index 45cd42b3e..b78471c0f 100644 --- a/pkg/sentry/fsimpl/memfs/memfs.go +++ b/pkg/sentry/fsimpl/memfs/memfs.go @@ -137,7 +137,7 @@ type inode struct { impl interface{} // immutable } -func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode uint16) { +func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) { i.refs = 1 i.mode = uint32(mode) i.uid = uint32(creds.EffectiveKUID) diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go index 55f869798..b7f4853b3 100644 --- a/pkg/sentry/fsimpl/memfs/regular_file.go +++ b/pkg/sentry/fsimpl/memfs/regular_file.go @@ -37,7 +37,7 @@ type regularFile struct { dataLen int64 } -func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *inode { +func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { file := ®ularFile{} file.inode.init(file, fs, creds, mode) file.inode.nlink = 1 // from parent directory diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index e76ee27d2..cf2a56bed 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -8,6 +8,8 @@ go_library( "error.go", "flags.go", "linux64.go", + "linux64_amd64.go", + "linux64_arm64.go", "sigset.go", "sys_aio.go", "sys_capability.go", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 72c383537..b64c49ff5 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,4 +19,4 @@ const ( _LINUX_SYSNAME = "Linux" _LINUX_RELEASE = "4.4" _LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016" -)
\ No newline at end of file +) diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go index e6f07595c..e215ac049 100644 --- a/pkg/sentry/syscalls/linux/linux64_amd64.go +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package linux provides syscall tables for amd64 Linux. package linux import ( @@ -228,7 +227,7 @@ var AMD64 = &kernel.SyscallTable{ 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil), 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), 186: syscalls.Supported("gettid", Gettid), - 187: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341) + 187: syscalls.Supported("readahead", Readahead), 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), @@ -316,21 +315,21 @@ var AMD64 = &kernel.SyscallTable{ 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 276: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 275: syscalls.Supported("splice", Splice), + 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) 280: syscalls.Supported("utimensat", Utimensat), 281: syscalls.Supported("epoll_pwait", EpollPwait), - 282: syscalls.ErrorWithEvent("signalfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) + 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), 283: syscalls.Supported("timerfd_create", TimerfdCreate), 284: syscalls.Supported("eventfd", Eventfd), 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), 286: syscalls.Supported("timerfd_settime", TimerfdSettime), 287: syscalls.Supported("timerfd_gettime", TimerfdGettime), 288: syscalls.Supported("accept4", Accept4), - 289: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) + 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), 290: syscalls.Supported("eventfd2", Eventfd2), 291: syscalls.Supported("epoll_create1", EpollCreate1), 292: syscalls.Supported("dup3", Dup3), diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go index 57c36bc4a..1d3b63020 100644 --- a/pkg/sentry/syscalls/linux/linux64_arm64.go +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package linux provides syscall tables for arm64 Linux. package linux import ( diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 9f705ebca..dd3a5807f 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -159,9 +159,14 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc }, outFile.Flags().NonBlocking) } + // Sendfile can't lose any data because inFD is always a regual file. + if n != 0 { + err = nil + } + // We can only pass a single file to handleIOError, so pick inFile // arbitrarily. This is used only for debugging purposes. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "sendfile", inFile) + return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "sendfile", inFile) } // Splice implements splice(2). @@ -305,6 +310,11 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo Dup: true, }, nonBlock) + // Tee doesn't change a state of inFD, so it can't lose any data. + if n != 0 { + err = nil + } + // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile) + return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "tee", inFile) } diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 187e5410c..3aa73d911 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -31,14 +31,14 @@ type GetDentryOptions struct { // FilesystemImpl.MkdirAt(). type MkdirOptions struct { // Mode is the file mode bits for the created directory. - Mode uint16 + Mode linux.FileMode } // MknodOptions contains options to VirtualFilesystem.MknodAt() and // FilesystemImpl.MknodAt(). type MknodOptions struct { // Mode is the file type and mode bits for the created file. - Mode uint16 + Mode linux.FileMode // If Mode specifies a character or block device special file, DevMajor and // DevMinor are the major and minor device numbers for the created device. @@ -61,7 +61,7 @@ type OpenOptions struct { // If FilesystemImpl.OpenAt() creates a file, Mode is the file mode for the // created file. - Mode uint16 + Mode linux.FileMode } // ReadOptions contains options to FileDescription.PRead(), diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index b6641ccc3..a53894c01 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -47,9 +47,6 @@ func TestExcludeBroadcast(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { t.Fatalf("AddAddress failed: %v", err) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f6106f762..5993fe582 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -104,6 +104,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback func (n *NIC) enable() *tcpip.Error { n.attachLinkEndpoint() + // Create an endpoint to receive broadcast packets on this interface. + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + if err := n.AddAddress(tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, + }, NeverPrimaryEndpoint); err != nil { + return err + } + } + // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives // and responds to the various NDP messages that are destined to the @@ -372,7 +382,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { - // Don't include expired or tempory endpoints to avoid confusion and + // Don't include expired or temporary endpoints to avoid confusion and // prevent the caller from using those. switch ref.getKind() { case permanentExpired, temporary: @@ -624,21 +634,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.stack.AddLinkAddress(n.id, src, remote) - // If the packet is destined to the IPv4 Broadcast address, then make a - // route to each IPv4 network endpoint and let each endpoint handle the - // packet. - if dst == header.IPv4Broadcast { - // n.endpoints is mutex protected so acquire lock. - n.mu.RLock() - for _, ref := range n.endpoints { - if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) - } - } - n.mu.RUnlock() - return - } - if ref := n.getRef(protocol, dst); ref != nil { handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 5c8b7977a..0b09e6517 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -59,6 +59,8 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop = PacketLoop } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop } return Route{ @@ -208,10 +210,17 @@ func (r *Route) Clone() Route { return *r } -// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast. +// MakeLoopedRoute duplicates the given route with special handling for routes +// used for sending multicast or broadcast packets. In those cases the +// multicast/broadcast address is the remote address when sending out, but for +// incoming (looped) packets it becomes the local address. Similarly, the local +// interface address that was the local address going out becomes the remote +// address coming in. This is different to unicast routes where local and +// remote addresses remain the same as they identify location (local vs remote) +// not direction (source vs destination). func (r *Route) MakeLoopedRoute() Route { l := r.Clone() - if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { + if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress l.RemoteLinkAddress = l.LocalLinkAddress } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 90c2cf1be..ff574a055 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -902,7 +902,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } else { for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) { + if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } if nic, ok := s.nics[route.NIC]; ok { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d2dede8a9..a2e0a6e7b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -952,10 +952,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. testSendTo(t, s, dstAddr, ep, nil) @@ -967,10 +967,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != localAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. testSend(t, r, ep, nil) @@ -1016,17 +1016,33 @@ func TestSpoofingNoAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } -func TestBroadcastNeedsNoRoute(t *testing.T) { +func verifyRoute(gotRoute, wantRoute stack.Route) error { + if gotRoute.LocalAddress != wantRoute.LocalAddress { + return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) + } + if gotRoute.RemoteAddress != wantRoute.RemoteAddress { + return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) + } + if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { + return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + } + if gotRoute.NextHop != wantRoute.NextHop { + return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) + } + return nil +} + +func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) @@ -1039,28 +1055,99 @@ func TestBroadcastNeedsNoRoute(t *testing.T) { // If there is no endpoint, it won't work. if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } - if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err) + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + if err := s.AddProtocolAddress(1, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if r.LocalAddress != header.IPv4Any { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any) + // If the NIC doesn't exist, it won't work. + if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { + t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } +} + +func TestOutgoingBroadcastWithRouteTable(t *testing.T) { + defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. + nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. + nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") - if r.RemoteAddress != header.IPv4Broadcast { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast) + // Create a new stack with two NICs. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + if err := s.CreateNIC(2, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { + t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + } + + // Set the initial route table. + rt := []tcpip.Route{ + {Destination: nic1Addr.Subnet(), NIC: 1}, + {Destination: nic2Addr.Subnet(), NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, + } + s.SetRouteTable(rt) + + // When an interface is given, the route for a broadcast goes through it. + r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + + // When an interface is not given, it consults the route table. + // 1. Case: Using the default route. + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) + } + + // 2. Case: Having an explicit route for broadcast will select that one. + rt = append( + []tcpip.Route{ + {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + }, + rt..., + ) + s.SetRouteTable(rt) + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 8c768c299..92267ce4d 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -63,7 +63,7 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { + if isMulticastOrBroadcast(id.LocalAddress) { mpep.handlePacketAll(r, id, vv) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return @@ -338,23 +338,14 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // If a sender bound to the Loopback interface sends a broadcast, - // that broadcast must not be delivered to the sender. - if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort { - return false - } - - // If the packet is a broadcast, then find all matching transport endpoints. - // Otherwise, try to find a single matching transport endpoint. - destEps := make([]*endpointsByNic, 0, 1) eps.mu.RLock() - if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { - for epID, endpoint := range eps.endpoints { - if epID.LocalPort == id.LocalPort { - destEps = append(destEps, endpoint) - } - } + // Determine which transport endpoint or endpoints to deliver this packet to. + // If the packet is a broadcast or multicast, then find all matching + // transport endpoints. + var destEps []*endpointsByNic + if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + destEps = d.findAllEndpointsLocked(eps, vv, id) } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { destEps = append(destEps, ep) } @@ -426,10 +417,11 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { + var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the local address. @@ -437,7 +429,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the remote part. @@ -445,15 +437,24 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.RemoteAddress = "" nid.RemotePort = 0 if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with only the local port. nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } + return matchedEPs +} + +// findEndpointLocked returns the endpoint that most closely matches the given +// id. +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { + return matchedEPs[0] + } return nil } @@ -491,3 +492,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } } } + +func isMulticastOrBroadcast(addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index faaa4a4e3..70e7575f5 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -57,6 +57,9 @@ type Error struct { // String implements fmt.Stringer.String. func (e *Error) String() string { + if e == nil { + return "<nil>" + } return e.msg } @@ -1095,6 +1098,47 @@ func (a AddressWithPrefix) String() string { return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen) } +// Subnet converts the address and prefix into a Subnet value and returns it. +func (a AddressWithPrefix) Subnet() Subnet { + addrLen := len(a.Address) + if a.PrefixLen <= 0 { + return Subnet{ + address: Address(strings.Repeat("\x00", addrLen)), + mask: AddressMask(strings.Repeat("\x00", addrLen)), + } + } + if a.PrefixLen >= addrLen*8 { + return Subnet{ + address: a.Address, + mask: AddressMask(strings.Repeat("\xff", addrLen)), + } + } + + sa := make([]byte, addrLen) + sm := make([]byte, addrLen) + n := uint(a.PrefixLen) + for i := 0; i < addrLen; i++ { + if n >= 8 { + sa[i] = a.Address[i] + sm[i] = 0xff + n -= 8 + continue + } + sm[i] = ^byte(0xff >> n) + sa[i] = a.Address[i] & sm[i] + n = 0 + } + + // For extra caution, call NewSubnet rather than directly creating the Subnet + // value. If that fails it indicates a serious bug in this code, so panic is + // in order. + s, err := NewSubnet(Address(sa), AddressMask(sm)) + if err != nil { + panic("invalid subnet: " + err.Error()) + } + return s +} + // ProtocolAddress is an address and the network protocol it is associated // with. type ProtocolAddress struct { diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index fb3a0a5ee..8c0aacffa 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -195,3 +195,34 @@ func TestStatsString(t *testing.T) { t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) } } + +func TestAddressWithPrefixSubnet(t *testing.T) { + tests := []struct { + addr Address + prefixLen int + subnetAddr Address + subnetMask AddressMask + }{ + {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"}, + {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"}, + {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + } + for _, tt := range tests { + ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen} + gotSubnet := ap.Subnet() + wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) + if err != nil { + t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) + continue + } + if gotSubnet != wantSubnet { + t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet) + } + } +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 5059ca22d..d1ba8d2ce 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -532,10 +532,11 @@ func TestBindToDeviceOption(t *testing.T) { } } -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// testReadInternal sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then attempts to read it from the +// UDP endpoint and depending on if this was expected to succeed verifies its +// correctness. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) { c.t.Helper() payload := newPayload() @@ -553,27 +554,51 @@ func testRead(c *testContext, flow testFlow) { select { case <-ch: v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") + case <-time.After(300 * time.Millisecond): + if packetShouldBeDropped { + return // expected to time out + } + c.t.Fatal("timed out waiting for data") } } + if err != nil { + c.t.Fatal("Read failed:", err) + } + + if packetShouldBeDropped { + c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) + } + // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr) + c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) } // Check the payload. if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) + c.t.Fatalf("bad payload: got %x, want %x", v, payload) } } +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + testReadInternal(c, flow, false /* packetShouldBeDropped */) +} + +// testFailingRead sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then tries to read it from the UDP +// endpoint and expects this to fail. +func testFailingRead(c *testContext, flow testFlow) { + c.t.Helper() + testReadInternal(c, flow, true /* packetShouldBeDropped */) +} + func TestBindEphemeralPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -743,13 +768,17 @@ func TestReadOnBoundToMulticast(t *testing.T) { c.t.Fatal("SetSockOpt failed:", err) } + // Check that we receive multicast packets but not unicast or broadcast + // ones. testRead(c, flow) + testFailingRead(c, broadcast) + testFailingRead(c, unicastV4) }) } } // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and receive broadcast data on it. +// address and can receive only broadcast data. func TestV4ReadOnBoundToBroadcast(t *testing.T) { for _, flow := range []testFlow{broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -764,8 +793,31 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { c.t.Fatalf("Bind failed: %s", err) } - // Test acceptance. + // Check that we receive broadcast packets but not unicast ones. + testRead(c, flow) + testFailingRead(c, unicastV4) + }) + } +} + +// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY +// and receive broadcast and unicast data. +func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s (", err) + } + + // Check that we receive both broadcast and unicast packets. testRead(c, flow) + testRead(c, unicastV4) }) } } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 341e6b252..5a3c0896c 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -185,7 +185,7 @@ syscall_test( ) syscall_test( - size = "medium", + size = "large", shard_count = 5, test = "//test/syscalls/linux:itimer_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index d5a2b7725..b9fee9ee9 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -333,6 +333,7 @@ cc_binary( linkstatic = 1, deps = [ ":socket_test_util", + "//test/util:file_descriptor", "//test/util:test_main", "//test/util:test_util", "@com_google_googletest//:gtest", @@ -1904,6 +1905,7 @@ cc_binary( srcs = ["sendfile.cc"], linkstatic = 1, deps = [ + "//test/util:eventfd_util", "//test/util:file_descriptor", "//test/util:temp_path", "//test/util:test_main", diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index 410b42a47..f81ed22c8 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -60,6 +60,14 @@ SocketPairKind IPv6TCPAcceptBindSocketPair(int type) { /* dual_stack = */ false)}; } +SocketKind IPv6TCPUnboundSocket(int type) { + std::string description = + absl::StrCat(DescribeSocketType(type), "IPv6 TCP socket"); + return SocketKind{ + description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, + UnboundSocketCreator(AF_INET6, type | SOCK_STREAM, IPPROTO_TCP)}; +} + SocketPairKind IPv4TCPAcceptBindSocketPair(int type) { std::string description = absl::StrCat(DescribeSocketType(type), "connected IPv4 TCP socket"); diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 3d36b9620..a7825770e 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -96,6 +96,10 @@ SocketKind IPv4UDPUnboundSocket(int type); // a SimpleSocket created with AF_INET, SOCK_STREAM and the given type. SocketKind IPv4TCPUnboundSocket(int type); +// IPv6TCPUnboundSocketPair returns a SocketKind that represents +// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type. +SocketKind IPv6TCPUnboundSocket(int type); + // IfAddrHelper is a helper class that determines the local interfaces present // and provides functions to obtain their names, index numbers, and IP address. class IfAddrHelper { diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc index f6d7ad0bb..f61795592 100644 --- a/test/syscalls/linux/proc_net_tcp.cc +++ b/test/syscalls/linux/proc_net_tcp.cc @@ -250,6 +250,247 @@ TEST(ProcNetTCP, State) { EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED); } +constexpr char kProcNetTCP6Header[] = + " sl local_address remote_address" + " st tx_queue rx_queue tr tm->when retrnsmt" + " uid timeout inode"; + +// TCP6Entry represents a single entry from /proc/net/tcp6. +struct TCP6Entry { + struct in6_addr local_addr; + uint16_t local_port; + + struct in6_addr remote_addr; + uint16_t remote_port; + + uint64_t state; + uint64_t uid; + uint64_t inode; +}; + +bool IPv6AddrEqual(const struct in6_addr* a1, const struct in6_addr* a2) { + return memcmp(a1, a2, sizeof(struct in6_addr)) == 0; +} + +// Finds the first entry in 'entries' for which 'predicate' returns true. +// Returns true on match, and sets 'match' to a copy of the matching entry. If +// 'match' is null, it's ignored. +bool FindBy6(const std::vector<TCP6Entry>& entries, TCP6Entry* match, + std::function<bool(const TCP6Entry&)> predicate) { + for (const TCP6Entry& entry : entries) { + if (predicate(entry)) { + if (match != nullptr) { + *match = entry; + } + return true; + } + } + return false; +} + +const struct in6_addr* IP6FromInetSockaddr(const struct sockaddr* addr) { + auto* addr6 = reinterpret_cast<const struct sockaddr_in6*>(addr); + return &addr6->sin6_addr; +} + +bool FindByLocalAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match, + const struct sockaddr* addr) { + const struct in6_addr* local = IP6FromInetSockaddr(addr); + uint16_t port = PortFromInetSockaddr(addr); + return FindBy6(entries, match, [local, port](const TCP6Entry& e) { + return (IPv6AddrEqual(&e.local_addr, local) && e.local_port == port); + }); +} + +bool FindByRemoteAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match, + const struct sockaddr* addr) { + const struct in6_addr* remote = IP6FromInetSockaddr(addr); + uint16_t port = PortFromInetSockaddr(addr); + return FindBy6(entries, match, [remote, port](const TCP6Entry& e) { + return (IPv6AddrEqual(&e.remote_addr, remote) && e.remote_port == port); + }); +} + +void ReadIPv6Address(std::string s, struct in6_addr* addr) { + uint32_t a0, a1, a2, a3; + const char* fmt = "%08X%08X%08X%08X"; + EXPECT_EQ(sscanf(s.c_str(), fmt, &a0, &a1, &a2, &a3), 4); + + uint8_t* b = addr->s6_addr; + *((uint32_t*)&b[0]) = a0; + *((uint32_t*)&b[4]) = a1; + *((uint32_t*)&b[8]) = a2; + *((uint32_t*)&b[12]) = a3; +} + +// Returns a parsed representation of /proc/net/tcp6 entries. +PosixErrorOr<std::vector<TCP6Entry>> ProcNetTCP6Entries() { + std::string content; + RETURN_IF_ERRNO(GetContents("/proc/net/tcp6", &content)); + + bool found_header = false; + std::vector<TCP6Entry> entries; + std::vector<std::string> lines = StrSplit(content, '\n'); + std::cerr << "<contents of /proc/net/tcp6>" << std::endl; + for (const std::string& line : lines) { + std::cerr << line << std::endl; + + if (!found_header) { + EXPECT_EQ(line, kProcNetTCP6Header); + found_header = true; + continue; + } + if (line.empty()) { + continue; + } + + // Parse a single entry from /proc/net/tcp6. + // + // Example entries: + // + // clang-format off + // + // sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode + // 0: 00000000000000000000000000000000:1F90 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 876340 1 ffff8803da9c9380 100 0 0 10 0 + // 1: 00000000000000000000000000000000:C350 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 876987 1 ffff8803ec408000 100 0 0 10 0 + // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 + // + // clang-format on + + TCP6Entry entry; + std::vector<std::string> fields = + StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty()); + + ReadIPv6Address(fields[1], &entry.local_addr); + ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16)); + ReadIPv6Address(fields[3], &entry.remote_addr); + ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16)); + ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16)); + ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11])); + ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13])); + + entries.push_back(entry); + } + std::cerr << "<end of /proc/net/tcp6>" << std::endl; + + return entries; +} + +TEST(ProcNetTCP6, Exists) { + const std::string content = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/tcp6")); + const std::string header_line = StrCat(kProcNetTCP6Header, "\n"); + if (IsRunningOnGvisor()) { + // Should be just the header since we don't have any tcp sockets yet. + EXPECT_EQ(content, header_line); + } else { + // On a general linux machine, we could have abitrary sockets on the system, + // so just check the header. + EXPECT_THAT(content, ::testing::StartsWith(header_line)); + } +} + +TEST(ProcNetTCP6, EntryUID) { + auto sockets = + ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create()); + std::vector<TCP6Entry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + TCP6Entry e; + + ASSERT_TRUE(FindByLocalAddr6(entries, &e, sockets->first_addr())); + EXPECT_EQ(e.uid, geteuid()); + ASSERT_TRUE(FindByRemoteAddr6(entries, &e, sockets->first_addr())); + EXPECT_EQ(e.uid, geteuid()); +} + +TEST(ProcNetTCP6, BindAcceptConnect) { + auto sockets = + ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create()); + std::vector<TCP6Entry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + // We can only make assertions about the total number of entries if we control + // the entire "machine". + if (IsRunningOnGvisor()) { + EXPECT_EQ(entries.size(), 2); + } + + EXPECT_TRUE(FindByLocalAddr6(entries, nullptr, sockets->first_addr())); + EXPECT_TRUE(FindByRemoteAddr6(entries, nullptr, sockets->first_addr())); +} + +TEST(ProcNetTCP6, InodeReasonable) { + auto sockets = + ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create()); + std::vector<TCP6Entry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + + TCP6Entry accepted_entry; + + ASSERT_TRUE( + FindByLocalAddr6(entries, &accepted_entry, sockets->first_addr())); + EXPECT_NE(accepted_entry.inode, 0); + + TCP6Entry client_entry; + ASSERT_TRUE(FindByRemoteAddr6(entries, &client_entry, sockets->first_addr())); + EXPECT_NE(client_entry.inode, 0); + EXPECT_NE(accepted_entry.inode, client_entry.inode); +} + +TEST(ProcNetTCP6, State) { + std::unique_ptr<FileDescriptor> server = + ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPUnboundSocket(0).Create()); + + auto test_addr = V6Loopback(); + ASSERT_THAT( + bind(server->get(), reinterpret_cast<struct sockaddr*>(&test_addr.addr), + test_addr.addr_len), + SyscallSucceeds()); + + struct sockaddr_in6 addr6; + socklen_t addrlen = sizeof(struct sockaddr_in6); + auto* addr = reinterpret_cast<struct sockaddr*>(&addr6); + ASSERT_THAT(getsockname(server->get(), addr, &addrlen), SyscallSucceeds()); + ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); + + ASSERT_THAT(listen(server->get(), 10), SyscallSucceeds()); + std::vector<TCP6Entry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + TCP6Entry listen_entry; + + ASSERT_TRUE(FindByLocalAddr6(entries, &listen_entry, addr)); + EXPECT_EQ(listen_entry.state, TCP_LISTEN); + + std::unique_ptr<FileDescriptor> client = + ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPUnboundSocket(0).Create()); + ASSERT_THAT(RetryEINTR(connect)(client->get(), addr, addrlen), + SyscallSucceeds()); + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + ASSERT_TRUE(FindByLocalAddr6(entries, &listen_entry, addr)); + EXPECT_EQ(listen_entry.state, TCP_LISTEN); + TCP6Entry client_entry; + ASSERT_TRUE(FindByRemoteAddr6(entries, &client_entry, addr)); + EXPECT_EQ(client_entry.state, TCP_ESTABLISHED); + + FileDescriptor accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(server->get(), nullptr, nullptr)); + + const struct in6_addr* local = IP6FromInetSockaddr(addr); + const uint16_t accepted_local_port = PortFromInetSockaddr(addr); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + TCP6Entry accepted_entry; + ASSERT_TRUE(FindBy6( + entries, &accepted_entry, + [client_entry, local, accepted_local_port](const TCP6Entry& e) { + return IPv6AddrEqual(&e.local_addr, local) && + e.local_port == accepted_local_port && + IPv6AddrEqual(&e.remote_addr, &client_entry.local_addr) && + e.remote_port == client_entry.local_port; + })); + EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 4502e7fb4..580ab5193 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <sys/eventfd.h> #include <sys/sendfile.h> #include <unistd.h> @@ -21,6 +22,7 @@ #include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/util/eventfd_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -511,6 +513,23 @@ TEST(SendFileTest, SendPipeBlocks) { SyscallSucceedsWithValue(kDataSize)); } +TEST(SendFileTest, SendToSpecialFile) { + // Create temp file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); + + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); + constexpr int kSize = 0x7ff; + ASSERT_THAT(ftruncate(inf.get(), kSize), SyscallSucceeds()); + + auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); + + // eventfd can accept a number of bytes which is a multiple of 8. + EXPECT_THAT(sendfile(eventfd.get(), inf.get(), nullptr, 0xfffff), + SyscallSucceedsWithValue(kSize & (~7))); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index caae215b8..3a07ac8d2 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -17,6 +17,7 @@ #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" #include "test/util/test_util.h" namespace gvisor { @@ -57,5 +58,28 @@ TEST(SocketTest, ProtocolInet) { } } +using SocketOpenTest = ::testing::TestWithParam<int>; + +// UDS cannot be opened. +TEST_P(SocketOpenTest, Unix) { + // FIXME(b/142001530): Open incorrectly succeeds on gVisor. + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor bound = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX)); + + struct sockaddr_un addr = + ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX)); + + ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + EXPECT_THAT(open(addr.sun_path, GetParam()), SyscallFailsWithErrno(ENXIO)); +} + +INSTANTIATE_TEST_SUITE_P(OpenModes, SocketOpenTest, + ::testing::Values(O_RDONLY, O_RDWR)); + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index c85ae30dc..8b8993d3d 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -42,6 +42,26 @@ TestAddress V4EmptyAddress() { return t; } +constexpr char kMulticastAddress[] = "224.0.2.1"; + +TestAddress V4Multicast() { + TestAddress t("V4Multicast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + inet_addr(kMulticastAddress); + return t; +} + +TestAddress V4Broadcast() { + TestAddress t("V4Broadcast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + htonl(INADDR_BROADCAST); + return t; +} + void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { got_if_infos_ = false; @@ -116,7 +136,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SetUDPBroadcast) { // Verifies that a broadcast UDP packet will arrive at all UDP sockets with // the destination port number. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - UDPBroadcastReceivedOnAllExpectedEndpoints) { + UDPBroadcastReceivedOnExpectedPort) { auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -136,51 +156,134 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); - sockaddr_in rcv_addr = {}; - socklen_t rcv_addr_sz = sizeof(rcv_addr); - rcv_addr.sin_family = AF_INET; - rcv_addr.sin_addr.s_addr = htonl(INADDR_ANY); - ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<struct sockaddr*>(&rcv_addr), - rcv_addr_sz), + // Bind the first socket to the ANY address and let the system assign a port. + auto rcv1_addr = V4Any(); + ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), + rcv1_addr.addr_len), SyscallSucceedsWithValue(0)); // Retrieve port number from first socket so that it can be bound to the // second socket. - rcv_addr = {}; + socklen_t rcv_addr_sz = rcv1_addr.addr_len; ASSERT_THAT( - getsockname(rcvr1->get(), reinterpret_cast<struct sockaddr*>(&rcv_addr), + getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), &rcv_addr_sz), SyscallSucceedsWithValue(0)); - ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<struct sockaddr*>(&rcv_addr), + EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len); + auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port; + + // Bind the second socket to the same address:port as the first. + ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), rcv_addr_sz), SyscallSucceedsWithValue(0)); // Bind the non-receiving socket to an ephemeral port. - sockaddr_in norcv_addr = {}; - norcv_addr.sin_family = AF_INET; - norcv_addr.sin_addr.s_addr = htonl(INADDR_ANY); + auto norecv_addr = V4Any(); + ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), + norecv_addr.addr_len), + SyscallSucceedsWithValue(0)); + + // Broadcast a test message. + auto dst_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port; + constexpr char kTestMsg[] = "hello, world"; + EXPECT_THAT( + sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + + // Verify that the receiving sockets received the test message. + char buf[sizeof(kTestMsg)] = {}; + EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); + memset(buf, 0, sizeof(buf)); + EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); + + // Verify that the non-receiving socket did not receive the test message. + memset(buf, 0, sizeof(buf)); + EXPECT_THAT(RetryEINTR(recv)(norcv->get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// Verifies that a broadcast UDP packet will arrive at all UDP sockets bound to +// the destination port number and either INADDR_ANY or INADDR_BROADCAST, but +// not a unicast address. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + UDPBroadcastReceivedOnExpectedAddresses) { + // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its + // IPv4 address on eth0. + SKIP_IF(!got_if_infos_); + + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto norcv = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Enable SO_BROADCAST on the sending socket. + ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + // Enable SO_REUSEPORT on all sockets so that they may all be bound to the + // broadcast messages destination port. + ASSERT_THAT(setsockopt(rcvr1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(setsockopt(rcvr2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(setsockopt(norcv->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + // Bind the first socket the ANY address and let the system assign a port. + auto rcv1_addr = V4Any(); + ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), + rcv1_addr.addr_len), + SyscallSucceedsWithValue(0)); + // Retrieve port number from first socket so that it can be bound to the + // second socket. + socklen_t rcv_addr_sz = rcv1_addr.addr_len; ASSERT_THAT( - bind(norcv->get(), reinterpret_cast<struct sockaddr*>(&norcv_addr), - sizeof(norcv_addr)), + getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), + &rcv_addr_sz), SyscallSucceedsWithValue(0)); + EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len); + auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port; + + // Bind the second socket to the broadcast address. + auto rcv2_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&rcv2_addr.addr)->sin_port = port; + ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv2_addr.addr), + rcv2_addr.addr_len), + SyscallSucceedsWithValue(0)); + + // Bind the non-receiving socket to the unicast ethernet address. + auto norecv_addr = rcv1_addr; + reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr = + eth_if_sin_addr_; + ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), + norecv_addr.addr_len), + SyscallSucceedsWithValue(0)); // Broadcast a test message. - sockaddr_in dst_addr = {}; - dst_addr.sin_family = AF_INET; - dst_addr.sin_addr.s_addr = htonl(INADDR_BROADCAST); - dst_addr.sin_port = rcv_addr.sin_port; + auto dst_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port; constexpr char kTestMsg[] = "hello, world"; EXPECT_THAT( sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<struct sockaddr*>(&dst_addr), sizeof(dst_addr)), + reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), SyscallSucceedsWithValue(sizeof(kTestMsg))); // Verify that the receiving sockets received the test message. char buf[sizeof(kTestMsg)] = {}; - EXPECT_THAT(read(rcvr1->get(), buf, sizeof(buf)), + EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); memset(buf, 0, sizeof(buf)); - EXPECT_THAT(read(rcvr2->get(), buf, sizeof(buf)), + EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); @@ -190,10 +293,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SyscallFailsWithErrno(EAGAIN)); } -// Verifies that a UDP broadcast sent via the loopback interface is not received -// by the sender. +// Verifies that a UDP broadcast can be sent and then received back on the same +// socket that is bound to the broadcast address (255.255.255.255). +// FIXME(b/141938460): This can be combined with the next test +// (UDPBroadcastSendRecvOnSocketBoundToAny). TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - UDPBroadcastViaLoopbackFails) { + UDPBroadcastSendRecvOnSocketBoundToBroadcast) { auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Enable SO_BROADCAST. @@ -201,33 +306,73 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); - // Bind the sender to the loopback interface. - sockaddr_in src = {}; - socklen_t src_sz = sizeof(src); - src.sin_family = AF_INET; - src.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT( - bind(sender->get(), reinterpret_cast<struct sockaddr*>(&src), src_sz), - SyscallSucceedsWithValue(0)); + // Bind the sender to the broadcast address. + auto src_addr = V4Broadcast(); + ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr), + src_addr.addr_len), + SyscallSucceedsWithValue(0)); + socklen_t src_sz = src_addr.addr_len; ASSERT_THAT(getsockname(sender->get(), - reinterpret_cast<struct sockaddr*>(&src), &src_sz), + reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz), SyscallSucceedsWithValue(0)); - ASSERT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(src_sz, src_addr.addr_len); // Send the message. - sockaddr_in dst = {}; - dst.sin_family = AF_INET; - dst.sin_addr.s_addr = htonl(INADDR_BROADCAST); - dst.sin_port = src.sin_port; + auto dst_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port; constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<struct sockaddr*>(&dst), sizeof(dst)), + EXPECT_THAT( + sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + + // Verify that the message was received. + char buf[sizeof(kTestMsg)] = {}; + EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); +} + +// Verifies that a UDP broadcast can be sent and then received back on the same +// socket that is bound to the ANY address (0.0.0.0). +// FIXME(b/141938460): This can be combined with the previous test +// (UDPBroadcastSendRecvOnSocketBoundToBroadcast). +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + UDPBroadcastSendRecvOnSocketBoundToAny) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - // Verify that the message was not received by the sender (loopback). + // Enable SO_BROADCAST. + ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + // Bind the sender to the ANY address. + auto src_addr = V4Any(); + ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr), + src_addr.addr_len), + SyscallSucceedsWithValue(0)); + socklen_t src_sz = src_addr.addr_len; + ASSERT_THAT(getsockname(sender->get(), + reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(src_sz, src_addr.addr_len); + + // Send the message. + auto dst_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port; + constexpr char kTestMsg[] = "hello, world"; + EXPECT_THAT( + sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + + // Verify that the message was received. char buf[sizeof(kTestMsg)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); } // Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST @@ -237,15 +382,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) { // Broadcast a test message without having enabled SO_BROADCAST on the sending // socket. - sockaddr_in addr = {}; - socklen_t addr_sz = sizeof(addr); - addr.sin_family = AF_INET; - addr.sin_port = htons(12345); - addr.sin_addr.s_addr = htonl(INADDR_BROADCAST); + auto addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = htons(12345); constexpr char kTestMsg[] = "hello, world"; EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<struct sockaddr*>(&addr), addr_sz), + reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len), SyscallFailsWithErrno(EACCES)); } @@ -274,21 +416,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendUnicastOnUnbound) { reinterpret_cast<struct sockaddr*>(&addr), addr_sz), SyscallSucceedsWithValue(sizeof(kTestMsg))); char buf[sizeof(kTestMsg)] = {}; - ASSERT_THAT(read(rcvr->get(), buf, sizeof(buf)), + ASSERT_THAT(recv(rcvr->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); } -constexpr char kMulticastAddress[] = "224.0.2.1"; - -TestAddress V4Multicast() { - TestAddress t("V4Multicast"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = - inet_addr(kMulticastAddress); - return t; -} - // Check that multicast packets won't be delivered to the sending socket with no // set interface or group membership. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, @@ -609,8 +740,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, } // Check that two sockets can join the same multicast group at the same time, -// and both will receive data on it. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { +// and both will receive data on it when bound to the ANY address. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + TestSendMulticastToTwoBoundToAny) { auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -624,8 +756,72 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - // Bind the receiver to the v4 any address to ensure that we can receive the - // multicast packet. + // Bind to ANY to receive multicast packets. + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + EXPECT_EQ( + htonl(INADDR_ANY), + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); + // On the first iteration, save the port we are bound to. On the second + // iteration, verify the port is the same as the one from the first + // iteration. In other words, both sockets listen on the same port. + if (bound_port == 0) { + bound_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + } else { + EXPECT_EQ(bound_port, + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); + } + + // Register to receive multicast packets. + ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + } + + // Send a multicast packet to the group and verify both receivers get it. + auto send_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + for (auto& receiver : receivers) { + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT( + RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } +} + +// Check that two sockets can join the same multicast group at the same time, +// and both will receive data on it when bound to the multicast address. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + TestSendMulticastToTwoBoundToMulticastAddress) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + std::unique_ptr<FileDescriptor> receivers[2] = { + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())}; + + ip_mreq group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + auto receiver_addr = V4Multicast(); + int bound_port = 0; + for (auto& receiver : receivers) { + ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); ASSERT_THAT( bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), receiver_addr.addr_len), @@ -636,6 +832,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + EXPECT_EQ( + inet_addr(kMulticastAddress), + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); // On the first iteration, save the port we are bound to. On the second // iteration, verify the port is the same as the one from the first // iteration. In other words, both sockets listen on the same port. @@ -643,6 +842,83 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { bound_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; } else { + EXPECT_EQ( + inet_addr(kMulticastAddress), + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); + EXPECT_EQ(bound_port, + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); + } + + // Register to receive multicast packets. + ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + } + + // Send a multicast packet to the group and verify both receivers get it. + auto send_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + for (auto& receiver : receivers) { + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT( + RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } +} + +// Check that two sockets can join the same multicast group at the same time, +// and with one bound to the wildcard address and the other bound to the +// multicast address, both will receive data. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + TestSendMulticastToTwoBoundToAnyAndMulticastAddress) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + std::unique_ptr<FileDescriptor> receivers[2] = { + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())}; + + ip_mreq group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + // The first receiver binds to the wildcard address. + auto receiver_addr = V4Any(); + int bound_port = 0; + for (auto& receiver : receivers) { + ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + // On the first iteration, save the port we are bound to and change the + // receiver address from V4Any to V4Multicast so the second receiver binds + // to that. On the second iteration, verify the port is the same as the one + // from the first iteration but the address is different. + if (bound_port == 0) { + EXPECT_EQ( + htonl(INADDR_ANY), + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); + bound_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + receiver_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port = + bound_port; + } else { + EXPECT_EQ( + inet_addr(kMulticastAddress), + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); EXPECT_EQ(bound_port, reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); } diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 6efa8055f..70710195c 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -83,6 +83,8 @@ inline ssize_t SendFd(int fd, void* buf, size_t count, int flags) { count); } +PosixErrorOr<struct sockaddr_un> UniqueUnixAddr(bool abstract, int domain); + // A Creator<T> is a function that attempts to create and return a new T. (This // is copy/pasted from cloud/gvisor/api/sandbox_util.h and is just duplicated // here for clarity.) |