diff options
Diffstat (limited to 'pkg')
29 files changed, 479 insertions, 169 deletions
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 7d742871a..257f67222 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -271,7 +271,7 @@ type Statx struct { } // FileMode represents a mode_t. -type FileMode uint +type FileMode uint16 // Permissions returns just the permission bits. func (m FileMode) Permissions() FileMode { diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 51967b811..1b5dac99a 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -11,8 +11,9 @@ go_library( "bits.go", "bits32.go", "bits64.go", - "uint64_arch_amd64.go", + "uint64_arch.go", "uint64_arch_amd64_asm.s", + "uint64_arch_arm64_asm.s", "uint64_arch_generic.go", ], importpath = "gvisor.dev/gvisor/pkg/bits", diff --git a/pkg/bits/uint64_arch_amd64.go b/pkg/bits/uint64_arch.go index faccaa61a..9f23eff77 100644 --- a/pkg/bits/uint64_arch_amd64.go +++ b/pkg/bits/uint64_arch.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 package bits diff --git a/pkg/bits/uint64_arch_arm64_asm.s b/pkg/bits/uint64_arch_arm64_asm.s new file mode 100644 index 000000000..814ba562d --- /dev/null +++ b/pkg/bits/uint64_arch_arm64_asm.s @@ -0,0 +1,33 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +TEXT ·TrailingZeros64(SB),$0-16 + MOVD x+0(FP), R0 + RBIT R0, R0 + CLZ R0, R0 // return 64 if x == 0 + MOVD R0, ret+8(FP) + RET + +TEXT ·MostSignificantOne64(SB),$0-16 + MOVD x+0(FP), R0 + CLZ R0, R0 // return 64 if x == 0 + MOVD $63, R1 + SUBS R0, R1, R0 // ret = 63 - CLZ + BPL end + MOVD $64, R0 // x == 0 +end: + MOVD R0, ret+8(FP) + RET diff --git a/pkg/bits/uint64_arch_generic.go b/pkg/bits/uint64_arch_generic.go index 7dd2d1480..9dd2098d1 100644 --- a/pkg/bits/uint64_arch_generic.go +++ b/pkg/bits/uint64_arch_generic.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !amd64 +// +build !amd64,!arm64 package bits diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 2412aa5e1..221516c6c 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -505,12 +505,27 @@ func (c *Client) sendRecvChannel(t message, r message) error { ch.active = false c.channelsMu.Unlock() c.channelsWg.Done() - return err + // Map all transport errors to EIO, but ensure that the real error + // is logged. + log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err) + return syscall.EIO } } - // Send the message. - err := ch.sendRecv(c, t, r) + // Send the request and receive the server's response. + rsz, err := ch.send(t) + if err != nil { + // See above. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err) + return syscall.EIO + } + + // Parse the server's response. + _, retErr := ch.recv(r, rsz) // Release the channel. c.channelsMu.Lock() @@ -523,7 +538,7 @@ func (c *Client) sendRecvChannel(t message, r message) error { c.channelsMu.Unlock() c.channelsWg.Done() - return err + return retErr } // Version returns the negotiated 9P2000.L.Google version number. diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 69c886a5d..e717e6161 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -452,7 +452,9 @@ func (cs *connState) initializeChannels() (err error) { cs.channelWg.Add(1) go func() { // S/R-SAFE: Server side. defer cs.channelWg.Done() - res.service(cs) + if err := res.service(cs); err != nil { + log.Warningf("p9.channel.service: %v", err) + } }() } diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go index 7cdf4ecc3..233f825e3 100644 --- a/pkg/p9/transport_flipcall.go +++ b/pkg/p9/transport_flipcall.go @@ -132,7 +132,7 @@ func (ch *channel) send(m message) (uint32, error) { if filer, ok := m.(filer); ok { if f := filer.FilePayload(); f != nil { if err := ch.fds.SendFD(f.FD()); err != nil { - return 0, syscall.EIO // Map everything to EIO. + return 0, err } f.Close() // Per sendRecvLegacy. sentFD = true // To mark below. @@ -162,15 +162,7 @@ func (ch *channel) send(m message) (uint32, error) { } // Perform the one-shot communication. - n, err := ch.data.SendRecv(ssz) - if err != nil { - if n > 0 { - return n, nil - } - return 0, syscall.EIO // See above. - } - - return n, nil + return ch.data.SendRecv(ssz) } // recv decodes a message that exists on the channel. @@ -249,15 +241,3 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) { return r, nil } - -// sendRecv sends the given message over the channel. -// -// This is used by the client. -func (ch *channel) sendRecv(c *Client, m, r message) error { - rsz, err := ch.send(m) - if err != nil { - return err - } - _, err = ch.recv(r, rsz) - return err -} diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 246b97161..5a388dad1 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "fmt" "strings" "gvisor.dev/gvisor/pkg/abi/linux" @@ -207,6 +208,11 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name } func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name string, flags FileFlags, perm FilePermissions) (*File, error) { + // Sanity check. + if parent.Inode.overlay == nil { + panic(fmt.Sprintf("overlayCreate called with non-overlay parent inode (parent InodeOperations type is %T)", parent.Inode.InodeOperations)) + } + // Dirent.Create takes renameMu if the Inode is an overlay Inode. if err := copyUpLockedForRename(ctx, parent); err != nil { return nil, err diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 5e28982c5..f70239449 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -66,7 +66,7 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo if s.SupportsIPv6() { contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc) contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte("")) - contents["tcp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")) + contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc) contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")) } } @@ -310,44 +310,51 @@ func networkToHost16(n uint16) uint16 { return usermem.ByteOrder.Uint16(buf[:]) } -func writeInetAddr(w io.Writer, a linux.SockAddrInet) { - // linux.SockAddrInet.Port is stored in the network byte order and is - // printed like a number in host byte order. Note that all numbers in host - // byte order are printed with the most-significant byte first when - // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux. - port := networkToHost16(a.Port) - - // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order - // (i.e. most-significant byte in index 0). Linux represents this as a - // __be32 which is a typedef for an unsigned int, and is printed with - // %X. This means that for a little-endian machine, Linux prints the - // least-significant byte of the address first. To emulate this, we first - // invert the byte order for the address using usermem.ByteOrder.Uint32, - // which makes it have the equivalent encoding to a __be32 on a little - // endian machine. Note that this operation is a no-op on a big endian - // machine. Then similar to Linux, we format it with %X, which will print - // the most-significant byte of the __be32 address first, which is now - // actually the least-significant byte of the original address in - // linux.SockAddrInet.Addr on little endian machines, due to the conversion. - addr := usermem.ByteOrder.Uint32(a.Addr[:]) - - fmt.Fprintf(w, "%08X:%04X ", addr, port) -} +func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { + switch family { + case linux.AF_INET: + var a linux.SockAddrInet + if i != nil { + a = *i.(*linux.SockAddrInet) + } -// netTCP implements seqfile.SeqSource for /proc/net/tcp. -// -// +stateify savable -type netTCP struct { - k *kernel.Kernel -} + // linux.SockAddrInet.Port is stored in the network byte order and is + // printed like a number in host byte order. Note that all numbers in host + // byte order are printed with the most-significant byte first when + // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux. + port := networkToHost16(a.Port) + + // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order + // (i.e. most-significant byte in index 0). Linux represents this as a + // __be32 which is a typedef for an unsigned int, and is printed with + // %X. This means that for a little-endian machine, Linux prints the + // least-significant byte of the address first. To emulate this, we first + // invert the byte order for the address using usermem.ByteOrder.Uint32, + // which makes it have the equivalent encoding to a __be32 on a little + // endian machine. Note that this operation is a no-op on a big endian + // machine. Then similar to Linux, we format it with %X, which will print + // the most-significant byte of the __be32 address first, which is now + // actually the least-significant byte of the original address in + // linux.SockAddrInet.Addr on little endian machines, due to the conversion. + addr := usermem.ByteOrder.Uint32(a.Addr[:]) + + fmt.Fprintf(w, "%08X:%04X ", addr, port) + case linux.AF_INET6: + var a linux.SockAddrInet6 + if i != nil { + a = *i.(*linux.SockAddrInet6) + } -// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. -func (*netTCP) NeedsUpdate(generation int64) bool { - return true + port := networkToHost16(a.Port) + addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4]) + addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8]) + addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12]) + addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16]) + fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port) + } } -// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. -func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { +func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kernel.Kernel, h seqfile.SeqHandle, fa int, header []byte) ([]seqfile.SeqData, int64) { // t may be nil here if our caller is not part of a task goroutine. This can // happen for example if we're here for "sentryctl cat". When t is nil, // degrade gracefully and retrieve what we can. @@ -358,7 +365,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se } var buf bytes.Buffer - for _, se := range n.k.ListSockets() { + for _, se := range k.ListSockets() { s := se.Sock.Get() if s == nil { log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) @@ -369,7 +376,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se if !ok { panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) } - if family, stype, _ := sops.Type(); !(family == linux.AF_INET && stype == linux.SOCK_STREAM) { + if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { s.DecRef() // Not tcp4 sockets. continue @@ -384,22 +391,22 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se fmt.Fprintf(&buf, "%4d: ", se.ID) // Field: local_adddress. - var localAddr linux.SockAddrInet + var localAddr linux.SockAddr if t != nil { if local, _, err := sops.GetSockName(t); err == nil { - localAddr = *local.(*linux.SockAddrInet) + localAddr = local } } - writeInetAddr(&buf, localAddr) + writeInetAddr(&buf, fa, localAddr) // Field: rem_address. - var remoteAddr linux.SockAddrInet + var remoteAddr linux.SockAddr if t != nil { if remote, _, err := sops.GetPeerName(t); err == nil { - remoteAddr = *remote.(*linux.SockAddrInet) + remoteAddr = remote } } - writeInetAddr(&buf, remoteAddr) + writeInetAddr(&buf, fa, remoteAddr) // Field: state; socket state. fmt.Fprintf(&buf, "%02X ", sops.State()) @@ -465,7 +472,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se data := []seqfile.SeqData{ { - Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n"), + Buf: header, Handle: n, }, { @@ -476,6 +483,42 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } +// netTCP implements seqfile.SeqSource for /proc/net/tcp. +// +// +stateify savable +type netTCP struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netTCP) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + header := []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n") + return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET, header) +} + +// netTCP6 implements seqfile.SeqSource for /proc/net/tcp6. +// +// +stateify savable +type netTCP6 struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netTCP6) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netTCP6) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + header := []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n") + return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET6, header) +} + // netUDP implements seqfile.SeqSource for /proc/net/udp. // // +stateify savable @@ -529,7 +572,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se localAddr = *local.(*linux.SockAddrInet) } } - writeInetAddr(&buf, localAddr) + writeInetAddr(&buf, linux.AF_INET, &localAddr) // Field: rem_address. var remoteAddr linux.SockAddrInet @@ -538,7 +581,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se remoteAddr = *remote.(*linux.SockAddrInet) } } - writeInetAddr(&buf, remoteAddr) + writeInetAddr(&buf, linux.AF_INET, &remoteAddr) // Field: state; socket state. fmt.Fprintf(&buf, "%02X ", sops.State()) diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index 0ef13f2f5..56e92721e 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -230,7 +230,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent // But for whatever crazy reason, you can still walk to the given node. for _, tg := range rpf.iops.pidns.ThreadGroups() { if leader := tg.Leader(); leader != nil { - name := strconv.FormatUint(uint64(tg.ID()), 10) + name := strconv.FormatUint(uint64(rpf.iops.pidns.IDOfThreadGroup(tg)), 10) m[name] = fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice) names = append(names, name) } diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go index c620227c9..0bd82e480 100644 --- a/pkg/sentry/fsimpl/memfs/directory.go +++ b/pkg/sentry/fsimpl/memfs/directory.go @@ -32,7 +32,7 @@ type directory struct { childList dentryList } -func (fs *filesystem) newDirectory(creds *auth.Credentials, mode uint16) *inode { +func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *inode { dir := &directory{} dir.inode.init(dir, fs, creds, mode) dir.inode.nlink = 2 // from "." and parent directory or ".." for root diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go index 45cd42b3e..b78471c0f 100644 --- a/pkg/sentry/fsimpl/memfs/memfs.go +++ b/pkg/sentry/fsimpl/memfs/memfs.go @@ -137,7 +137,7 @@ type inode struct { impl interface{} // immutable } -func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode uint16) { +func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) { i.refs = 1 i.mode = uint32(mode) i.uid = uint32(creds.EffectiveKUID) diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go index 55f869798..b7f4853b3 100644 --- a/pkg/sentry/fsimpl/memfs/regular_file.go +++ b/pkg/sentry/fsimpl/memfs/regular_file.go @@ -37,7 +37,7 @@ type regularFile struct { dataLen int64 } -func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *inode { +func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { file := ®ularFile{} file.inode.init(file, fs, creds, mode) file.inode.nlink = 1 // from parent directory diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index e76ee27d2..cf2a56bed 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -8,6 +8,8 @@ go_library( "error.go", "flags.go", "linux64.go", + "linux64_amd64.go", + "linux64_arm64.go", "sigset.go", "sys_aio.go", "sys_capability.go", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 72c383537..b64c49ff5 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,4 +19,4 @@ const ( _LINUX_SYSNAME = "Linux" _LINUX_RELEASE = "4.4" _LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016" -)
\ No newline at end of file +) diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go index e6f07595c..e215ac049 100644 --- a/pkg/sentry/syscalls/linux/linux64_amd64.go +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package linux provides syscall tables for amd64 Linux. package linux import ( @@ -228,7 +227,7 @@ var AMD64 = &kernel.SyscallTable{ 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil), 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), 186: syscalls.Supported("gettid", Gettid), - 187: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341) + 187: syscalls.Supported("readahead", Readahead), 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), @@ -316,21 +315,21 @@ var AMD64 = &kernel.SyscallTable{ 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 276: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 275: syscalls.Supported("splice", Splice), + 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) 280: syscalls.Supported("utimensat", Utimensat), 281: syscalls.Supported("epoll_pwait", EpollPwait), - 282: syscalls.ErrorWithEvent("signalfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) + 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), 283: syscalls.Supported("timerfd_create", TimerfdCreate), 284: syscalls.Supported("eventfd", Eventfd), 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), 286: syscalls.Supported("timerfd_settime", TimerfdSettime), 287: syscalls.Supported("timerfd_gettime", TimerfdGettime), 288: syscalls.Supported("accept4", Accept4), - 289: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) + 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), 290: syscalls.Supported("eventfd2", Eventfd2), 291: syscalls.Supported("epoll_create1", EpollCreate1), 292: syscalls.Supported("dup3", Dup3), diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go index 57c36bc4a..1d3b63020 100644 --- a/pkg/sentry/syscalls/linux/linux64_arm64.go +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package linux provides syscall tables for arm64 Linux. package linux import ( diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 9f705ebca..dd3a5807f 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -159,9 +159,14 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc }, outFile.Flags().NonBlocking) } + // Sendfile can't lose any data because inFD is always a regual file. + if n != 0 { + err = nil + } + // We can only pass a single file to handleIOError, so pick inFile // arbitrarily. This is used only for debugging purposes. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "sendfile", inFile) + return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "sendfile", inFile) } // Splice implements splice(2). @@ -305,6 +310,11 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo Dup: true, }, nonBlock) + // Tee doesn't change a state of inFD, so it can't lose any data. + if n != 0 { + err = nil + } + // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile) + return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "tee", inFile) } diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 187e5410c..3aa73d911 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -31,14 +31,14 @@ type GetDentryOptions struct { // FilesystemImpl.MkdirAt(). type MkdirOptions struct { // Mode is the file mode bits for the created directory. - Mode uint16 + Mode linux.FileMode } // MknodOptions contains options to VirtualFilesystem.MknodAt() and // FilesystemImpl.MknodAt(). type MknodOptions struct { // Mode is the file type and mode bits for the created file. - Mode uint16 + Mode linux.FileMode // If Mode specifies a character or block device special file, DevMajor and // DevMinor are the major and minor device numbers for the created device. @@ -61,7 +61,7 @@ type OpenOptions struct { // If FilesystemImpl.OpenAt() creates a file, Mode is the file mode for the // created file. - Mode uint16 + Mode linux.FileMode } // ReadOptions contains options to FileDescription.PRead(), diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index b6641ccc3..a53894c01 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -47,9 +47,6 @@ func TestExcludeBroadcast(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { t.Fatalf("AddAddress failed: %v", err) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f6106f762..5993fe582 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -104,6 +104,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback func (n *NIC) enable() *tcpip.Error { n.attachLinkEndpoint() + // Create an endpoint to receive broadcast packets on this interface. + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + if err := n.AddAddress(tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, + }, NeverPrimaryEndpoint); err != nil { + return err + } + } + // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives // and responds to the various NDP messages that are destined to the @@ -372,7 +382,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { - // Don't include expired or tempory endpoints to avoid confusion and + // Don't include expired or temporary endpoints to avoid confusion and // prevent the caller from using those. switch ref.getKind() { case permanentExpired, temporary: @@ -624,21 +634,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.stack.AddLinkAddress(n.id, src, remote) - // If the packet is destined to the IPv4 Broadcast address, then make a - // route to each IPv4 network endpoint and let each endpoint handle the - // packet. - if dst == header.IPv4Broadcast { - // n.endpoints is mutex protected so acquire lock. - n.mu.RLock() - for _, ref := range n.endpoints { - if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) - } - } - n.mu.RUnlock() - return - } - if ref := n.getRef(protocol, dst); ref != nil { handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 5c8b7977a..0b09e6517 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -59,6 +59,8 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop = PacketLoop } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop } return Route{ @@ -208,10 +210,17 @@ func (r *Route) Clone() Route { return *r } -// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast. +// MakeLoopedRoute duplicates the given route with special handling for routes +// used for sending multicast or broadcast packets. In those cases the +// multicast/broadcast address is the remote address when sending out, but for +// incoming (looped) packets it becomes the local address. Similarly, the local +// interface address that was the local address going out becomes the remote +// address coming in. This is different to unicast routes where local and +// remote addresses remain the same as they identify location (local vs remote) +// not direction (source vs destination). func (r *Route) MakeLoopedRoute() Route { l := r.Clone() - if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { + if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress l.RemoteLinkAddress = l.LocalLinkAddress } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 90c2cf1be..ff574a055 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -902,7 +902,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } else { for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) { + if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } if nic, ok := s.nics[route.NIC]; ok { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d2dede8a9..a2e0a6e7b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -952,10 +952,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. testSendTo(t, s, dstAddr, ep, nil) @@ -967,10 +967,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != localAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. testSend(t, r, ep, nil) @@ -1016,17 +1016,33 @@ func TestSpoofingNoAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } -func TestBroadcastNeedsNoRoute(t *testing.T) { +func verifyRoute(gotRoute, wantRoute stack.Route) error { + if gotRoute.LocalAddress != wantRoute.LocalAddress { + return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) + } + if gotRoute.RemoteAddress != wantRoute.RemoteAddress { + return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) + } + if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { + return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + } + if gotRoute.NextHop != wantRoute.NextHop { + return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) + } + return nil +} + +func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) @@ -1039,28 +1055,99 @@ func TestBroadcastNeedsNoRoute(t *testing.T) { // If there is no endpoint, it won't work. if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } - if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err) + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + if err := s.AddProtocolAddress(1, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if r.LocalAddress != header.IPv4Any { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any) + // If the NIC doesn't exist, it won't work. + if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { + t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } +} + +func TestOutgoingBroadcastWithRouteTable(t *testing.T) { + defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. + nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. + nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") - if r.RemoteAddress != header.IPv4Broadcast { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast) + // Create a new stack with two NICs. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + if err := s.CreateNIC(2, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { + t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + } + + // Set the initial route table. + rt := []tcpip.Route{ + {Destination: nic1Addr.Subnet(), NIC: 1}, + {Destination: nic2Addr.Subnet(), NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, + } + s.SetRouteTable(rt) + + // When an interface is given, the route for a broadcast goes through it. + r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + + // When an interface is not given, it consults the route table. + // 1. Case: Using the default route. + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) + } + + // 2. Case: Having an explicit route for broadcast will select that one. + rt = append( + []tcpip.Route{ + {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + }, + rt..., + ) + s.SetRouteTable(rt) + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 8c768c299..92267ce4d 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -63,7 +63,7 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { + if isMulticastOrBroadcast(id.LocalAddress) { mpep.handlePacketAll(r, id, vv) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return @@ -338,23 +338,14 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // If a sender bound to the Loopback interface sends a broadcast, - // that broadcast must not be delivered to the sender. - if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort { - return false - } - - // If the packet is a broadcast, then find all matching transport endpoints. - // Otherwise, try to find a single matching transport endpoint. - destEps := make([]*endpointsByNic, 0, 1) eps.mu.RLock() - if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { - for epID, endpoint := range eps.endpoints { - if epID.LocalPort == id.LocalPort { - destEps = append(destEps, endpoint) - } - } + // Determine which transport endpoint or endpoints to deliver this packet to. + // If the packet is a broadcast or multicast, then find all matching + // transport endpoints. + var destEps []*endpointsByNic + if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + destEps = d.findAllEndpointsLocked(eps, vv, id) } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { destEps = append(destEps, ep) } @@ -426,10 +417,11 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { + var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the local address. @@ -437,7 +429,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the remote part. @@ -445,15 +437,24 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.RemoteAddress = "" nid.RemotePort = 0 if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with only the local port. nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } + return matchedEPs +} + +// findEndpointLocked returns the endpoint that most closely matches the given +// id. +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { + return matchedEPs[0] + } return nil } @@ -491,3 +492,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } } } + +func isMulticastOrBroadcast(addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index faaa4a4e3..70e7575f5 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -57,6 +57,9 @@ type Error struct { // String implements fmt.Stringer.String. func (e *Error) String() string { + if e == nil { + return "<nil>" + } return e.msg } @@ -1095,6 +1098,47 @@ func (a AddressWithPrefix) String() string { return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen) } +// Subnet converts the address and prefix into a Subnet value and returns it. +func (a AddressWithPrefix) Subnet() Subnet { + addrLen := len(a.Address) + if a.PrefixLen <= 0 { + return Subnet{ + address: Address(strings.Repeat("\x00", addrLen)), + mask: AddressMask(strings.Repeat("\x00", addrLen)), + } + } + if a.PrefixLen >= addrLen*8 { + return Subnet{ + address: a.Address, + mask: AddressMask(strings.Repeat("\xff", addrLen)), + } + } + + sa := make([]byte, addrLen) + sm := make([]byte, addrLen) + n := uint(a.PrefixLen) + for i := 0; i < addrLen; i++ { + if n >= 8 { + sa[i] = a.Address[i] + sm[i] = 0xff + n -= 8 + continue + } + sm[i] = ^byte(0xff >> n) + sa[i] = a.Address[i] & sm[i] + n = 0 + } + + // For extra caution, call NewSubnet rather than directly creating the Subnet + // value. If that fails it indicates a serious bug in this code, so panic is + // in order. + s, err := NewSubnet(Address(sa), AddressMask(sm)) + if err != nil { + panic("invalid subnet: " + err.Error()) + } + return s +} + // ProtocolAddress is an address and the network protocol it is associated // with. type ProtocolAddress struct { diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index fb3a0a5ee..8c0aacffa 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -195,3 +195,34 @@ func TestStatsString(t *testing.T) { t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) } } + +func TestAddressWithPrefixSubnet(t *testing.T) { + tests := []struct { + addr Address + prefixLen int + subnetAddr Address + subnetMask AddressMask + }{ + {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"}, + {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"}, + {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + } + for _, tt := range tests { + ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen} + gotSubnet := ap.Subnet() + wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) + if err != nil { + t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) + continue + } + if gotSubnet != wantSubnet { + t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet) + } + } +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 5059ca22d..d1ba8d2ce 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -532,10 +532,11 @@ func TestBindToDeviceOption(t *testing.T) { } } -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// testReadInternal sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then attempts to read it from the +// UDP endpoint and depending on if this was expected to succeed verifies its +// correctness. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) { c.t.Helper() payload := newPayload() @@ -553,27 +554,51 @@ func testRead(c *testContext, flow testFlow) { select { case <-ch: v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") + case <-time.After(300 * time.Millisecond): + if packetShouldBeDropped { + return // expected to time out + } + c.t.Fatal("timed out waiting for data") } } + if err != nil { + c.t.Fatal("Read failed:", err) + } + + if packetShouldBeDropped { + c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) + } + // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr) + c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) } // Check the payload. if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) + c.t.Fatalf("bad payload: got %x, want %x", v, payload) } } +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + testReadInternal(c, flow, false /* packetShouldBeDropped */) +} + +// testFailingRead sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then tries to read it from the UDP +// endpoint and expects this to fail. +func testFailingRead(c *testContext, flow testFlow) { + c.t.Helper() + testReadInternal(c, flow, true /* packetShouldBeDropped */) +} + func TestBindEphemeralPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -743,13 +768,17 @@ func TestReadOnBoundToMulticast(t *testing.T) { c.t.Fatal("SetSockOpt failed:", err) } + // Check that we receive multicast packets but not unicast or broadcast + // ones. testRead(c, flow) + testFailingRead(c, broadcast) + testFailingRead(c, unicastV4) }) } } // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and receive broadcast data on it. +// address and can receive only broadcast data. func TestV4ReadOnBoundToBroadcast(t *testing.T) { for _, flow := range []testFlow{broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -764,8 +793,31 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { c.t.Fatalf("Bind failed: %s", err) } - // Test acceptance. + // Check that we receive broadcast packets but not unicast ones. + testRead(c, flow) + testFailingRead(c, unicastV4) + }) + } +} + +// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY +// and receive broadcast and unicast data. +func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s (", err) + } + + // Check that we receive both broadcast and unicast packets. testRead(c, flow) + testRead(c, unicastV4) }) } } |