diff options
Diffstat (limited to 'pkg')
47 files changed, 1121 insertions, 289 deletions
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index 5539466ff..f690ef5ad 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -95,9 +95,14 @@ type DeviceFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *DeviceFD) Release(context.Context) { +func (fd *DeviceFD) Release(ctx context.Context) { if fd.fs != nil { + fd.fs.conn.mu.Lock() fd.fs.conn.connected = false + fd.fs.conn.mu.Unlock() + + fd.fs.VFSFilesystem().DecRef(ctx) + fd.fs = nil } } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index a768b1a80..b3573f80d 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -218,6 +218,7 @@ func newFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOpt conn: conn, } + fs.VFSFilesystem().IncRef() fuseFD.fs = fs return fs, nil diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 03ce51df7..16abc59dc 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -1026,7 +1026,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open // step is required even if !d.cachedMetadataAuthoritative() because // d.mappings has to be updated. // d.metadataMu has already been acquired if trunc == true. - d.updateFileSizeLocked(0) + d.updateSizeLocked(0) if d.cachedMetadataAuthoritative() { d.touchCMtimeLocked() diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 0e21c31a4..aaad9c0d9 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -833,7 +833,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { atomic.StoreUint32(&d.nlink, uint32(attr.NLink)) } if mask.Size { - d.updateFileSizeLocked(attr.Size) + d.updateSizeLocked(attr.Size) } } @@ -987,7 +987,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // d.size should be kept up to date, and privatized // copy-on-write mappings of truncated pages need to be // invalidated, even if InteropModeShared is in effect. - d.updateFileSizeLocked(stat.Size) + d.updateSizeLocked(stat.Size) } } if d.fs.opts.interop == InteropModeShared { @@ -1024,8 +1024,31 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } +// doAllocate performs an allocate operation on d. Note that d.metadataMu will +// be held when allocate is called. +func (d *dentry) doAllocate(ctx context.Context, offset, length uint64, allocate func() error) error { + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + + // Allocating a smaller size is a noop. + size := offset + length + if d.cachedMetadataAuthoritative() && size <= d.size { + return nil + } + + err := allocate() + if err != nil { + return err + } + d.updateSizeLocked(size) + if d.cachedMetadataAuthoritative() { + d.touchCMtimeLocked() + } + return nil +} + // Preconditions: d.metadataMu must be locked. -func (d *dentry) updateFileSizeLocked(newSize uint64) { +func (d *dentry) updateSizeLocked(newSize uint64) { d.dataMu.Lock() oldSize := d.size atomic.StoreUint64(&d.size, newSize) diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index a2e9342d5..24f03ee94 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -79,28 +79,11 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error { // Allocate implements vfs.FileDescriptionImpl.Allocate. func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { d := fd.dentry() - d.metadataMu.Lock() - defer d.metadataMu.Unlock() - - // Allocating a smaller size is a noop. - size := offset + length - if d.cachedMetadataAuthoritative() && size <= d.size { - return nil - } - - d.handleMu.RLock() - err := d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length) - d.handleMu.RUnlock() - if err != nil { - return err - } - d.dataMu.Lock() - atomic.StoreUint64(&d.size, size) - d.dataMu.Unlock() - if d.cachedMetadataAuthoritative() { - d.touchCMtimeLocked() - } - return nil + return d.doAllocate(ctx, offset, length, func() error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + return d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + }) } // PRead implements vfs.FileDescriptionImpl.PRead. diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 3c39aa9b7..dc960e5bf 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -135,6 +136,16 @@ func (fd *specialFileFD) EventUnregister(e *waiter.Entry) { fd.fileDescription.EventUnregister(e) } +func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + if fd.isRegularFile { + d := fd.dentry() + return d.doAllocate(ctx, offset, length, func() error { + return fd.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + }) + } + return fd.FileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length) +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { if fd.seekable && offset < 0 { diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index afd57e015..db8536f26 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -560,12 +560,7 @@ func (f *fileDescription) Release(context.Context) { // Allocate implements vfs.FileDescriptionImpl.Allocate. func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error { - if !f.inode.seekable { - return syserror.ESPIPE - } - - // TODO(gvisor.dev/issue/3589): Implement Allocate for non-pipe hostfds. - return syserror.EOPNOTSUPP + return unix.Fallocate(f.inode.hostFD, uint32(mode), int64(offset), int64(length)) } // PRead implements vfs.FileDescriptionImpl.PRead. diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 4a3dee631..89ed265dc 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -658,9 +658,6 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error fs.mu.Lock() defer fs.mu.Unlock() - // Store the name before walkExistingLocked as rp will be advanced past the - // name in the following call. - name := rp.Component() vfsd, inode, err := fs.walkExistingLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { @@ -691,7 +688,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } - if err := parentDentry.inode.RmDir(ctx, name, vfsd); err != nil { + if err := parentDentry.inode.RmDir(ctx, d.name, vfsd); err != nil { virtfs.AbortDeleteDentry(vfsd) return err } @@ -771,9 +768,6 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error fs.mu.Lock() defer fs.mu.Unlock() - // Store the name before walkExistingLocked as rp will be advanced past the - // name in the following call. - name := rp.Component() vfsd, _, err := fs.walkExistingLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { @@ -799,7 +793,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { return err } - if err := parentDentry.inode.Unlink(ctx, name, vfsd); err != nil { + if err := parentDentry.inode.Unlink(ctx, d.name, vfsd); err != nil { virtfs.AbortDeleteDentry(vfsd) return err } diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go index 74cfd3799..6e04705c7 100644 --- a/pkg/sentry/fsimpl/overlay/non_directory.go +++ b/pkg/sentry/fsimpl/overlay/non_directory.go @@ -147,6 +147,16 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux return stat, nil } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + wrappedFD, err := fd.getCurrentFD(ctx) + if err != nil { + return err + } + defer wrappedFD.DecRef(ctx) + return wrappedFD.Allocate(ctx, mode, offset, length) +} + // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { d := fd.dentry() diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index a1fe906ed..26b117ca4 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -185,8 +185,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de Start: child.lowerMerkleVD, }, &vfs.GetXattrOptions{ Name: merkleOffsetInParentXattr, - // Offset is a 32 bit integer. - Size: sizeOfInt32, + Size: sizeOfStringInt32, }) // The Merkle tree file for the child should have been created and @@ -227,7 +226,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // the size of all its children's root hashes. dataSize, err := parentMerkleFD.GetXattr(ctx, &vfs.GetXattrOptions{ Name: merkleSizeXattr, - Size: sizeOfInt32, + Size: sizeOfStringInt32, }) // The Merkle tree file for the child should have been created and @@ -372,6 +371,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, Path: fspath.Parse(childMerkleFilename), }, &vfs.OpenOptions{ Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, }) if err != nil { return nil, err diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 3e0bcd02b..9182df317 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -57,8 +57,9 @@ const merkleOffsetInParentXattr = "user.merkle.offset" // whole file. For a directory, it's the size of all its children's root hashes. const merkleSizeXattr = "user.merkle.size" -// sizeOfInt32 is the size in bytes for a 32 bit integer in extended attributes. -const sizeOfInt32 = 4 +// sizeOfStringInt32 is the size for a 32 bit integer stored as string in +// extended attributes. The maximum value of a 32 bit integer is 10 digits. +const sizeOfStringInt32 = 10 // noCrashOnVerificationFailure indicates whether the sandbox should panic // whenever verification fails. If true, an error is returned instead of @@ -636,7 +637,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // dataSize is the size of the whole file. dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{ Name: merkleSizeXattr, - Size: sizeOfInt32, + Size: sizeOfStringInt32, }) // The Merkle tree file for the child should have been created and diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index f223d59e1..f61039f5b 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -67,6 +67,11 @@ func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlag return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (*VFSPipe) Allocate(context.Context, uint64, uint64, uint64) error { + return syserror.ESPIPE +} + // Open opens the pipe represented by vp. func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) { vp.mu.Lock() diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 8a1d52ebf..97bc6027f 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -97,11 +97,6 @@ func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal return ioctl(ctx, s.fd, uio, args) } -// Allocate implements vfs.FileDescriptionImpl.Allocate. -func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error { - return syserror.ENODEV -} - // PRead implements vfs.FileDescriptionImpl.PRead. func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { return 0, syserror.ESPIPE diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 3e1735079..871ea80ee 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -146,6 +146,10 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { case stack.FilterTable: table = stack.EmptyFilterTable() case stack.NATTable: + if ipv6 { + nflog("IPv6 redirection not yet supported (gvisor.dev/issue/3549)") + return syserr.ErrInvalidArgument + } table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 0bfd6c1f4..844acfede 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -97,17 +97,33 @@ func (*TCPMatcher) Name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader().View()) + // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved + // into the stack.Check codepath as matchers are added. + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if netHeader.TransportProtocol() != header.TCPProtocolNumber { + return false, false + } - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return false, false - } + // We don't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false + } - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - return false, true + case header.IPv6ProtocolNumber: + // As in Linux, we do not perform an IPv6 fragment check. See + // xt_action_param.fragoff in + // include/linux/netfilter/x_tables.h. + if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.TCPProtocolNumber { + return false, false } + + default: + // We don't know the network protocol. return false, false } diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 7ed05461d..63201201c 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -94,19 +94,33 @@ func (*UDPMatcher) Name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. - if netHeader.TransportProtocol() != header.UDPProtocolNumber { - return false, false - } + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if netHeader.TransportProtocol() != header.UDPProtocolNumber { + return false, false + } - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - return false, true + // We don't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false } + + case header.IPv6ProtocolNumber: + // As in Linux, we do not perform an IPv6 fragment check. See + // xt_action_param.fragoff in + // include/linux/netfilter/x_tables.h. + if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.UDPProtocolNumber { + return false, false + } + + default: + // We don't know the network protocol. return false, false } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index b462924af..816c89cfa 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1185,7 +1185,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam var v tcpip.LingerOption var linger linux.Linger if err := ep.GetSockOpt(&v); err != nil { - return &linger, nil + return nil, syserr.TranslateNetstackError(err) } if v.Enabled { @@ -1768,10 +1768,16 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int case linux.SOL_IP: return setSockOptIP(t, s, ep, name, optVal) + case linux.SOL_PACKET: + // gVisor doesn't support any SOL_PACKET options just return not + // supported. Returning nil here will result in tcpdump thinking AF_PACKET + // features are supported and proceed to use them and break. + t.Kernel().EmitUnimplementedEvent(t) + return syserr.ErrProtocolNotAvailable + case linux.SOL_UDP, linux.SOL_ICMPV6, - linux.SOL_RAW, - linux.SOL_PACKET: + linux.SOL_RAW: t.Kernel().EmitUnimplementedEvent(t) } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 08504560c..d6fc03520 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -746,6 +746,9 @@ type baseEndpoint struct { // path is not empty if the endpoint has been bound, // or may be used if the endpoint is connected. path string + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // EventRegister implements waiter.Waitable.EventRegister. @@ -841,8 +844,14 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess return n, err } -// SetSockOpt sets a socket option. Currently not supported. -func (e *baseEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { +// SetSockOpt sets a socket option. +func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { + switch v := opt.(type) { + case *tcpip.LingerOption: + e.Lock() + e.linger = *v + e.Unlock() + } return nil } @@ -945,8 +954,11 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch o := opt.(type) { case *tcpip.LingerOption: + e.Lock() + *o = e.linger + e.Unlock() return nil default: diff --git a/pkg/sentry/syscalls/linux/sys_sched.go b/pkg/sentry/syscalls/linux/sys_sched.go index cd6f4dd94..bfcf44b6f 100644 --- a/pkg/sentry/syscalls/linux/sys_sched.go +++ b/pkg/sentry/syscalls/linux/sys_sched.go @@ -30,7 +30,7 @@ const ( // // +marshal type SchedParam struct { - schedPriority int64 + schedPriority int32 } // SchedGetparam implements linux syscall sched_getparam(2). diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 2b29a3c3f..73bb36d3e 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -326,6 +326,9 @@ type FileDescriptionImpl interface { // Allocate grows the file to offset + length bytes. // Only mode == 0 is supported currently. // + // Allocate should return EISDIR on directories, ESPIPE on pipes, and ENODEV on + // other files where it is not supported. + // // Preconditions: The FileDescription was opened for writing. Allocate(ctx context.Context, mode, offset, length uint64) error diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 68b80a951..78da16bac 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -57,7 +57,11 @@ func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, err } // Allocate implements FileDescriptionImpl.Allocate analogously to -// fallocate called on regular file, directory or FIFO in Linux. +// fallocate called on an invalid type of file in Linux. +// +// Note that directories can rely on this implementation even though they +// should technically return EISDIR. Allocate should never be called for a +// directory, because it requires a writable fd. func (FileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error { return syserror.ENODEV } diff --git a/pkg/tcpip/header/parse/BUILD b/pkg/tcpip/header/parse/BUILD new file mode 100644 index 000000000..2adee9288 --- /dev/null +++ b/pkg/tcpip/header/parse/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "parse", + srcs = ["parse.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go new file mode 100644 index 000000000..522135557 --- /dev/null +++ b/pkg/tcpip/header/parse/parse.go @@ -0,0 +1,166 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package parse provides utilities to parse packets. +package parse + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// ARP populates pkt's network header with an ARP header found in +// pkt.Data. +// +// Returns true if the header was successfully parsed. +func ARP(pkt *stack.PacketBuffer) bool { + _, ok := pkt.NetworkHeader().Consume(header.ARPSize) + if ok { + pkt.NetworkProtocolNumber = header.ARPProtocolNumber + } + return ok +} + +// IPv4 parses an IPv4 packet found in pkt.Data and populates pkt's network +// header with the IPv4 header. +// +// Returns true if the header was successfully parsed. +func IPv4(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return false + } + ipHdr := header.IPv4(hdr) + + // Header may have options, determine the true header length. + headerLen := int(ipHdr.HeaderLength()) + if headerLen < header.IPv4MinimumSize { + // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in + // order for the packet to be valid. Figure out if we want to reject this + // case. + headerLen = header.IPv4MinimumSize + } + hdr, ok = pkt.NetworkHeader().Consume(headerLen) + if !ok { + return false + } + ipHdr = header.IPv4(hdr) + + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + return true +} + +// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network +// header with the IPv6 header. +func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return 0, 0, 0, false, false + } + ipHdr := header.IPv6(hdr) + + // dataClone consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + views := [8]buffer.View{} + dataClone := pkt.Data.Clone(views[:]) + dataClone.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + + // Iterate over the IPv6 extensions to find their length. + var nextHdr tcpip.TransportProtocolNumber + var extensionsSize int + +traverseExtensions: + for { + extHdr, done, err := it.Next() + if err != nil { + break + } + + // If we exhaust the extension list, the entire packet is the IPv6 header + // and (possibly) extensions. + if done { + extensionsSize = dataClone.Size() + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6FragmentExtHdr: + if fragID == 0 && fragOffset == 0 && !fragMore { + fragID = extHdr.ID() + fragOffset = extHdr.FragmentOffset() + fragMore = extHdr.More() + } + + case header.IPv6RawPayloadHeader: + // We've found the payload after any extensions. + extensionsSize = dataClone.Size() - extHdr.Buf.Size() + nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) + break traverseExtensions + + default: + // Any other extension is a no-op, keep looping until we find the payload. + } + } + + // Put the IPv6 header with extensions in pkt.NetworkHeader(). + hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) + if !ok { + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + } + ipHdr = header.IPv6(hdr) + pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + + return nextHdr, fragID, fragOffset, fragMore, true +} + +// UDP parses a UDP packet found in pkt.Data and populates pkt's transport +// header with the UDP header. +// +// Returns true if the header was successfully parsed. +func UDP(pkt *stack.PacketBuffer) bool { + _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) + return ok +} + +// TCP parses a TCP packet found in pkt.Data and populates pkt's transport +// header with the TCP header. +// +// Returns true if the header was successfully parsed. +func TCP(pkt *stack.PacketBuffer) bool { + // TCP header is variable length, peek at it first. + hdrLen := header.TCPMinimumSize + hdr, ok := pkt.Data.PullUp(hdrLen) + if !ok { + return false + } + + // If the header has options, pull those up as well. + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of + // packets. + hdrLen = offset + } + + _, ok = pkt.TransportHeader().Consume(hdrLen) + return ok +} diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 7cbc305e7..4aac12a8c 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/link/nested", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 4fb127978..560477926 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -195,49 +196,52 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var transProto uint8 src := tcpip.Address("unknown") dst := tcpip.Address("unknown") - id := 0 - size := uint16(0) + var size uint16 + var id uint32 var fragmentOffset uint16 var moreFragments bool - // Examine the packet using a new VV. Backing storage must not be written. - vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - + // Clone the packet buffer to not modify the original. + // + // We don't clone the original packet buffer so that the new packet buffer + // does not have any of its headers set. + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())}) switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { + if ok := parse.IPv4(pkt); !ok { return } - ipv4 := header.IPv4(hdr) + + ipv4 := header.IPv4(pkt.NetworkHeader().View()) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) - id = int(ipv4.ID()) + id = uint32(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) + proto, fragID, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { return } - ipv6 := header.IPv6(hdr) + + ipv6 := header.IPv6(pkt.NetworkHeader().View()) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() - transProto = ipv6.NextHeader() + transProto = uint8(proto) size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + id = fragID + moreFragments = fragMore + fragmentOffset = fragOffset case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { + if parse.ARP(pkt) { return } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + + arp := header.ARP(pkt.NetworkHeader().View()) log.Infof( "%s arp %s (%s) -> %s (%s) valid:%t", prefix, @@ -259,7 +263,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { break } @@ -296,7 +300,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) if !ok { break } @@ -331,11 +335,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { + if ok := parse.UDP(pkt); !ok { break } - udp := header.UDP(hdr) + + udp := header.UDP(pkt.TransportHeader().View()) if fragmentOffset == 0 { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -345,19 +349,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { + if ok := parse.TCP(pkt); !ok { break } - tcp := header.TCP(hdr) + + tcp := header.TCP(pkt.TransportHeader().View()) if fragmentOffset == 0 { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) break } - if offset > vv.Size() && !moreFragments { - details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size()) + if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size) break } diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 82c073e32..b40dde96b 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7aaee08c4..cb9225bd7 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -234,11 +235,7 @@ func (*protocol) Wait() {} // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - _, ok = pkt.NetworkHeader().Consume(header.ARPSize) - if !ok { - return 0, false, false - } - return 0, false, true + return 0, false, parse.ARP(pkt) } // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index c82593e71..f9c2aa980 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -13,6 +13,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index f4394749d..59c3101b5 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -238,11 +239,13 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. - // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than - // only NATted packets, but removing this check short circuits broadcasts - // before they are sent out to other hosts. + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) @@ -298,7 +301,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n, err } - // Slow Path as we are dropping some packets in the batch degrade to + // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { @@ -527,37 +530,14 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// Parse implements stack.TransportProtocol.Parse. +// Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return 0, false, false - } - ipHdr := header.IPv4(hdr) - - // Header may have options, determine the true header length. - headerLen := int(ipHdr.HeaderLength()) - if headerLen < header.IPv4MinimumSize { - // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in - // order for the packet to be valid. Figure out if we want to reject this - // case. - headerLen = header.IPv4MinimumSize - } - hdr, ok = pkt.NetworkHeader().Consume(headerLen) - if !ok { + if ok := parse.IPv4(pkt); !ok { return 0, false, false } - ipHdr = header.IPv4(hdr) - - // If this is a fragment, don't bother parsing the transport header. - parseTransportHeader := true - if ipHdr.More() || ipHdr.FragmentOffset() != 0 { - parseTransportHeader = false - } - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) - return ipHdr.TransportProtocol(), parseTransportHeader, true + ipHdr := header.IPv4(pkt.NetworkHeader().View()) + return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } // calculateMTU calculates the network-layer payload MTU based on the link-layer diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 5e50558e8..c1a560914 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -1046,3 +1046,211 @@ func TestReceiveFragments(t *testing.T) { }) } } + +func TestWritePacketsStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + linkEP stack.LinkEndpoint + expectSent int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: &limitedEP{nPackets}, + expectSent: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: &limitedEP{nPackets - 1}, + expectSent: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: &limitedEP{nPackets}, + expectSent: 0, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: &limitedEP{nPackets}, + expectSent: nPackets - 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + rt := buildRoute(t, nil, test.linkEP) + + var pbl stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(1).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pbl.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, err := rt.WritePackets(nil, pbl, stack.NetworkHeaderParams{}) + if err != nil { + t.Fatal(err) + } + + got := int(rt.Stats().IP.PacketsSent.Value()) + if got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got != nWritten { + t.Errorf("sent %d packets, WritePackets returned %d", got, nWritten) + } + }) + } +} + +func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + s.CreateNIC(1, linkEP) + const ( + src = "\x10\x00\x00\x01" + dst = "\x10\x00\x00\x02" + ) + s.AddAddress(1, ipv4.ProtocolNumber, src) + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute got %v, want %v", err, nil) + } + return rt +} + +// limitedEP is a link endpoint that writes up to a certain number of packets +// before returning errors. +type limitedEP struct { + limit int +} + +// MTU implements LinkEndpoint.MTU. +func (*limitedEP) MTU() uint32 { return 0 } + +// Capabilities implements LinkEndpoint.Capabilities. +func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*limitedEP) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if ep.limit == 0 { + return 0, tcpip.ErrInvalidEndpointState + } + nWritten := ep.limit + if nWritten > pkts.Len() { + nWritten = pkts.Len() + } + ep.limit -= nWritten + return nWritten, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// Attach implements LinkEndpoint.Attach. +func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*limitedEP) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*limitedEP) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index bcc64994e..cd5fe3ea8 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -13,6 +13,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e821a8bff..a4a4d6a21 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -107,6 +108,31 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + // iptables is telling us to drop the packet. + return nil + } + + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. + if pkt.NatDone { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) + return nil + } + } + if r.Loop&stack.PacketLoop != 0 { loopedR := r.MakeLoopedRoute() @@ -138,9 +164,46 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe e.addIPHeader(r, pb, params) } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) + ipt := e.stack.IPTables() + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + if len(dropped) == 0 && len(natPkts) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + + // Slow path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { + continue + } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + ep.HandlePacket(&route, pkt) + n++ + continue + } + } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + n++ + } + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + return n, nil } // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet @@ -169,6 +232,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false + // iptables filtering. All packets that reach here are intended for + // this machine and will not be forwarded. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + // iptables is telling us to drop the packet. + return + } + for firstHeader := true; ; firstHeader = false { extHdr, done, err := it.Next() if err != nil { @@ -504,75 +575,14 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// Parse implements stack.TransportProtocol.Parse. +// Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { return 0, false, false } - ipHdr := header.IPv6(hdr) - - // dataClone consists of: - // - Any IPv6 header bytes after the first 40 (i.e. extensions). - // - The transport header, if present. - // - Any other payload data. - views := [8]buffer.View{} - dataClone := pkt.Data.Clone(views[:]) - dataClone.TrimFront(header.IPv6MinimumSize) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) - - // Iterate over the IPv6 extensions to find their length. - // - // Parsing occurs again in HandlePacket because we don't track the - // extensions in PacketBuffer. Unfortunately, that means HandlePacket - // has to do the parsing work again. - var nextHdr tcpip.TransportProtocolNumber - foundNext := true - extensionsSize := 0 -traverseExtensions: - for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() { - if err != nil { - break - } - // If we exhaust the extension list, the entire packet is the IPv6 header - // and (possibly) extensions. - if done { - extensionsSize = dataClone.Size() - foundNext = false - break - } - - switch extHdr := extHdr.(type) { - case header.IPv6FragmentExtHdr: - // If this is an atomic fragment, we don't have to treat it specially. - if !extHdr.More() && extHdr.FragmentOffset() == 0 { - continue - } - // This is a non-atomic fragment and has to be re-assembled before we can - // examine the payload for a transport header. - foundNext = false - - case header.IPv6RawPayloadHeader: - // We've found the payload after any extensions. - extensionsSize = dataClone.Size() - extHdr.Buf.Size() - nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) - break traverseExtensions - - default: - // Any other extension is a no-op, keep looping until we find the payload. - } - } - - // Put the IPv6 header with extensions in pkt.NetworkHeader(). - hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) - if !ok { - panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) - } - ipHdr = header.IPv6(hdr) - pkt.Data.CapLength(int(ipHdr.PayloadLength())) - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber - return nextHdr, foundNext, true + return proto, !fragMore && fragOffset == 0, true } // calculateMTU calculates the network-layer payload MTU based on the link-layer diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 5f9822c49..354d3b60d 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -1709,3 +1709,211 @@ func TestInvalidIPv6Fragments(t *testing.T) { }) } } + +func TestWritePacketsStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + linkEP stack.LinkEndpoint + expectSent int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: &limitedEP{nPackets}, + expectSent: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: &limitedEP{nPackets - 1}, + expectSent: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: &limitedEP{nPackets}, + expectSent: 0, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: &limitedEP{nPackets}, + expectSent: nPackets - 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + rt := buildRoute(t, nil, test.linkEP) + + var pbl stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(1).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pbl.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, err := rt.WritePackets(nil, pbl, stack.NetworkHeaderParams{}) + if err != nil { + t.Fatal(err) + } + + got := int(rt.Stats().IP.PacketsSent.Value()) + if got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got != nWritten { + t.Errorf("sent %d packets, WritePackets returned %d", got, nWritten) + } + }) + } +} + +func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + s.CreateNIC(1, linkEP) + const ( + src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ) + s.AddAddress(1, ProtocolNumber, src) + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute got %v, want %v", err, nil) + } + return rt +} + +// limitedEP is a link endpoint that writes up to a certain number of packets +// before returning errors. +type limitedEP struct { + limit int +} + +// MTU implements LinkEndpoint.MTU. +func (*limitedEP) MTU() uint32 { return 0 } + +// Capabilities implements LinkEndpoint.Capabilities. +func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*limitedEP) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if ep.limit == 0 { + return 0, tcpip.ErrInvalidEndpointState + } + nWritten := ep.limit + if nWritten > pkts.Len() { + nWritten = pkts.Len() + } + ep.limit -= nWritten + return nWritten, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// Attach implements LinkEndpoint.Attach. +func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*limitedEP) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*limitedEP) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 0e33cbe92..b6ef04d32 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -57,7 +57,72 @@ const reaperDelay = 5 * time.Second // all packets. func DefaultTables() *IPTables { return &IPTables{ - tables: [numTables]Table{ + v4Tables: [numTables]Table{ + natID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + }, + mangleID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Output: 1, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, + }, + }, + filterID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + }, + }, + v6Tables: [numTables]Table{ natID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, @@ -166,25 +231,20 @@ func EmptyNATTable() Table { // GetTable returns a table by name. func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { - // TODO(gvisor.dev/issue/3549): Enable IPv6. - if ipv6 { - return Table{}, false - } id, ok := nameToID[name] if !ok { return Table{}, false } it.mu.RLock() defer it.mu.RUnlock() - return it.tables[id], true + if ipv6 { + return it.v6Tables[id], true + } + return it.v4Tables[id], true } // ReplaceTable replaces or inserts table by name. func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { - // TODO(gvisor.dev/issue/3549): Enable IPv6. - if ipv6 { - return tcpip.ErrInvalidOptionValue - } id, ok := nameToID[name] if !ok { return tcpip.ErrInvalidOptionValue @@ -198,7 +258,11 @@ func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Err it.startReaper(reaperDelay) } it.modified = true - it.tables[id] = table + if ipv6 { + it.v6Tables[id] = table + } else { + it.v4Tables[id] = table + } return nil } @@ -221,8 +285,17 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // +// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from +// which address and nicName can be gathered. Currently, address is only +// needed for prerouting and nicName is only needed for output. +// +// TODO(gvisor.dev/issue/170): Dropped packets should be counted. +// // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { + if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { + return true + } // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. it.mu.RLock() @@ -243,9 +316,14 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr if tableID == natID && pkt.NatDone { continue } - table := it.tables[tableID] + var table Table + if pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber { + table = it.v6Tables[tableID] + } else { + table = it.v4Tables[tableID] + } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -256,7 +334,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v { case RuleAccept: continue case RuleDrop: @@ -351,11 +429,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { case RuleAccept: return chainAccept @@ -372,7 +450,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -398,11 +476,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) { + if !rule.Filter.match(pkt, hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -421,7 +499,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) + return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr) } // OriginalDst returns the original destination of redirected connections. It diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index fbbd2f50f..093ee6881 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "strings" "sync" @@ -81,26 +82,25 @@ const ( // // +stateify savable type IPTables struct { - // mu protects tables, priorities, and modified. + // mu protects v4Tables, v6Tables, and modified. mu sync.RWMutex - - // tables maps tableIDs to tables. Holds builtin tables only, not user - // tables. mu must be locked for accessing. - tables [numTables]Table - - // priorities maps each hook to a list of table names. The order of the - // list is the order in which each table should be visited for that - // hook. mu needs to be locked for accessing. - priorities [NumHooks][]tableID - + // v4Tables and v6tables map tableIDs to tables. They hold builtin + // tables only, not user tables. mu must be locked for accessing. + v4Tables [numTables]Table + v6Tables [numTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. modified bool + // priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. It is immutable. + priorities [NumHooks][]tableID + connections ConnTrack - // reaperDone can be signalled to stop the reaper goroutine. + // reaperDone can be signaled to stop the reaper goroutine. reaperDone chan struct{} } @@ -148,7 +148,7 @@ type Rule struct { Target Target } -// IPHeaderFilter holds basic IP filtering data common to every rule. +// IPHeaderFilter performs basic IP header matching common to every rule. // // +stateify savable type IPHeaderFilter struct { @@ -196,16 +196,43 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } -// match returns whether hdr matches the filter. -func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. +// match returns whether pkt matches the filter. +// +// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 +// or IPv6 header length. +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { + // Extract header fields. + var ( + // TODO(gvisor.dev/issue/170): Support other filter fields. + transProto tcpip.TransportProtocolNumber + dstAddr tcpip.Address + srcAddr tcpip.Address + ) + switch proto := pkt.NetworkProtocolNumber; proto { + case header.IPv4ProtocolNumber: + hdr := header.IPv4(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + case header.IPv6ProtocolNumber: + hdr := header.IPv6(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + default: + panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto)) + } + // Check the transport protocol. - if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + if fl.CheckProtocol && fl.Protocol != transProto { return false } - // Check the source and destination IPs. - if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + // Check the addresses. + if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) || + !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) { return false } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 1f1a1426b..821d3feb9 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1282,9 +1282,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp return } - // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. // Loopback traffic skips the prerouting chain. - if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { + if !n.isLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 17b8beebb..1932aaeb7 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -80,7 +80,7 @@ type PacketBuffer struct { // data are held in the same underlying buffer storage. header buffer.Prependable - // NetworkProtocol is only valid when NetworkHeader is set. + // NetworkProtocolNumber is only valid when NetworkHeader is set. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol // numbers in registration APIs that take a PacketBuffer. NetworkProtocolNumber tcpip.NetworkProtocolNumber diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index ad71ff3b6..31116309e 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -74,6 +74,8 @@ type endpoint struct { route stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -344,9 +346,14 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch v := opt.(type) { case *tcpip.SocketDetachFilterOption: return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() } return nil } @@ -415,8 +422,17 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.mu.Lock() + *o = e.linger + e.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 8bd4e5e37..072601d2d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -83,6 +83,8 @@ type endpoint struct { stats tcpip.TransportEndpointStats `state:"nosave"` bound bool boundNIC tcpip.NICID + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` @@ -298,10 +300,16 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch v := opt.(type) { case *tcpip.SocketDetachFilterOption: return nil + case *tcpip.LingerOption: + ep.mu.Lock() + ep.linger = *v + ep.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -366,8 +374,17 @@ func (ep *endpoint) LastError() *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrNotSupported +func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + ep.mu.Lock() + *o = ep.linger + ep.mu.Unlock() + return nil + + default: + return tcpip.ErrNotSupported + } } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index fb03e6047..e37c00523 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -84,6 +84,8 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -511,10 +513,16 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch v := opt.(type) { case *tcpip.SocketDetachFilterOption: return nil + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -577,8 +585,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.mu.Lock() + *o = e.linger + e.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 234fb95ce..4778e7b1c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -69,6 +69,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index faea7f2bb..120483838 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1317,14 +1317,17 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { - // The endpoint cannot be written to if it's not connected. - if !e.EndpointState().connected() { - switch e.EndpointState() { - case StateError: - return 0, e.HardError - default: - return 0, tcpip.ErrClosedForSend - } + switch s := e.EndpointState(); { + case s == StateError: + return 0, e.HardError + case !s.connecting() && !s.connected(): + return 0, tcpip.ErrClosedForSend + case s.connecting(): + // As per RFC793, page 56, a send request arriving when in connecting + // state, can be queued to be completed after the state becomes + // connected. Return an error code for the caller of endpoint Write to + // try again, until the connection handshake is complete. + return 0, tcpip.ErrWouldBlock } // Check if the connection has already been closed for sends. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 63ec12be8..74a17af79 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" @@ -506,22 +507,7 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - // TCP header is variable length, peek at it first. - hdrLen := header.TCPMinimumSize - hdr, ok := pkt.Data.PullUp(hdrLen) - if !ok { - return false - } - - // If the header has options, pull those up as well. - if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { - // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of - // packets. - hdrLen = offset - } - - _, ok = pkt.TransportHeader().Consume(hdrLen) - return ok + return parse.TCP(pkt) } // NewProtocol returns a TCP transport protocol. diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index b5d2d0ba6..c78549424 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index b572c39db..518f636f0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -154,6 +154,9 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // +stateify savable @@ -810,6 +813,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() } return nil } @@ -966,6 +974,11 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() + case *tcpip.LingerOption: + e.mu.RLock() + *o = e.linger + e.mu.RUnlock() + default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 3f87e8057..7d6b91a75 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/waiter" @@ -219,8 +220,7 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) - return ok + return parse.UDP(pkt) } // NewProtocol returns a UDP transport protocol. diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index b7f873392..06fb823f6 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -332,13 +332,13 @@ func PollContext(ctx context.Context, cb func() error) error { } // WaitForHTTP tries GET requests on a port until the call succeeds or timeout. -func WaitForHTTP(port int, timeout time.Duration) error { +func WaitForHTTP(ip string, port int, timeout time.Duration) error { cb := func() error { c := &http.Client{ // Calculate timeout to be able to do minimum 5 attempts. Timeout: timeout / 5, } - url := fmt.Sprintf("http://localhost:%d/", port) + url := fmt.Sprintf("http://%s:%d/", ip, port) resp, err := c.Get(url) if err != nil { log.Printf("Waiting %s: %v", url, err) |