diff options
42 files changed, 1066 insertions, 329 deletions
@@ -16,7 +16,7 @@ # Helpful pretty-printer. MAKEBANNER := \033[1;34mmake\033[0m -submake = echo -e '$(MAKEBANNER) $1' >&2; sh -c '$(MAKE) $1' +submake = echo -e '$(MAKEBANNER) $1' >&2; $(MAKE) $1 # Described below. OPTIONS := diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 2b789c4ec..a4bb62013 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -72,6 +72,9 @@ go_library( "//pkg/abi", "//pkg/binary", "//pkg/bits", + "//pkg/usermem", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 46d8b0b42..a91f9f018 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -14,6 +14,14 @@ package linux +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" +) + // This file contains structures required to support netfilter, specifically // the iptables tool. @@ -76,6 +84,8 @@ const ( // IPTEntry is an iptable rule. It corresponds to struct ipt_entry in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTEntry struct { // IP is used to filter packets based on the IP header. IP IPTIP @@ -112,21 +122,41 @@ type IPTEntry struct { // SizeOfIPTEntry is the size of an IPTEntry. const SizeOfIPTEntry = 112 -// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This -// struct marshaled via the binary package to write an IPTEntry to userspace. +// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. +// KernelIPTEntry itself is not Marshallable but it implements some methods of +// marshal.Marshallable that help in other implementations of Marshallable. type KernelIPTEntry struct { - IPTEntry + Entry IPTEntry // Elems holds the data for all this rule's matches followed by the // target. It is variable length -- users have to iterate over any // matches and use TargetOffset and NextOffset to make sense of the // data. - Elems []byte + Elems primitive.ByteSlice +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTEntry) SizeBytes() int { + return ke.Entry.SizeBytes() + ke.Elems.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { + ke.Entry.MarshalBytes(dst) + ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { + ke.Entry.UnmarshalBytes(src) + ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } // IPTIP contains information for matching a packet's IP header. // It corresponds to struct ipt_ip in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTIP struct { // Src is the source IP address. Src InetAddr @@ -189,6 +219,8 @@ const SizeOfIPTIP = 84 // XTCounters holds packet and byte counts for a rule. It corresponds to struct // xt_counters in include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTCounters struct { // Pcnt is the packet count. Pcnt uint64 @@ -321,6 +353,8 @@ const SizeOfXTRedirectTarget = 56 // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetinfo struct { Name TableName ValidHooks uint32 @@ -336,6 +370,8 @@ const SizeOfIPTGetinfo = 84 // IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It // corresponds to struct ipt_get_entries in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetEntries struct { Name TableName Size uint32 @@ -350,13 +386,103 @@ type IPTGetEntries struct { const SizeOfIPTGetEntries = 40 // KernelIPTGetEntries is identical to IPTGetEntries, but includes the -// Entrytable field. This struct marshaled via the binary package to write an -// KernelIPTGetEntries to userspace. +// Entrytable field. This has been manually made marshal.Marshallable since it +// is dynamically sized. type KernelIPTGetEntries struct { IPTGetEntries Entrytable []KernelIPTEntry } +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTGetEntries) SizeBytes() int { + res := ke.IPTGetEntries.SizeBytes() + for _, entry := range ke.Entrytable { + res += entry.SizeBytes() + } + return res +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { + ke.IPTGetEntries.MarshalBytes(dst) + marshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) + marshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { + ke.IPTGetEntries.UnmarshalBytes(src) + unmarshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) + unmarshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// Packed implements marshal.Marshallable.Packed. +func (ke *KernelIPTGetEntries) Packed() bool { + // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an + // indirection to the actual data we want to marshal (the slice data + // pointer), and the memory for KernelIPTGetEntries contains the slice + // header which we don't want to marshal. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) { + // Fall back to safe Marshal because the type in not packed. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { + // Fall back to safe Unmarshal because the type in not packed. + ke.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := task.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) +} + +func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte { + buf := task.CopyScratchBuffer(ke.SizeBytes()) + ke.MarshalBytes(buf) + return buf +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := w.Write(buf) + return int64(length), err +} + +var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) + // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. @@ -374,12 +500,6 @@ type IPTReplace struct { // Entries [0]IPTEntry } -// KernelIPTReplace is identical to IPTReplace, but includes the Entries field. -type KernelIPTReplace struct { - IPTReplace - Entries [0]IPTEntry -} - // SizeOfIPTReplace is the size of an IPTReplace. const SizeOfIPTReplace = 96 @@ -392,6 +512,8 @@ func (en ExtensionName) String() string { } // TableName holds the name of a netfilter table. +// +// +marshal type TableName [XT_TABLE_MAXNAMELEN]byte // String implements fmt.Stringer. diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 95337c168..c24a8216e 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -234,6 +234,8 @@ const ( const SockAddrMax = 128 // InetAddr is struct in_addr, from uapi/linux/in.h. +// +// +marshal type InetAddr [4]byte // SockAddrInet is struct sockaddr_in, from uapi/linux/in.h. @@ -303,6 +305,8 @@ func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} // Linger is struct linger, from include/linux/socket.h. +// +// +marshal type Linger struct { OnOff int32 Linger int32 @@ -317,6 +321,8 @@ const SizeOfLinger = 8 // the end of this struct or within existing unusued space, so its size grows // over time. The current iteration is based on linux v4.17. New versions are // always backwards compatible. +// +// +marshal type TCPInfo struct { State uint8 CaState uint8 @@ -414,6 +420,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{})) // A ControlMessageCredentials is an SCM_CREDENTIALS socket control message. // // ControlMessageCredentials represents struct ucred from linux/socket.h. +// +// +marshal type ControlMessageCredentials struct { PID int32 UID uint32 diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index cd5f5049e..491495b16 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -150,11 +150,9 @@ afterSymlink: return nil, err } if d != d.parent && !d.cachedMetadataAuthoritative() { - _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask()) - if err != nil { + if err := d.parent.updateFromGetattr(ctx); err != nil { return nil, err } - d.parent.updateFromP9Attrs(attrMask, &attr) } rp.Advance() return d.parent, nil @@ -209,17 +207,28 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil // Preconditions: As for getChildLocked. !parent.isSynthetic(). func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) { + if child != nil { + // Need to lock child.metadataMu because we might be updating child + // metadata. We need to hold the lock *before* getting metadata from the + // server and release it after updating local metadata. + child.metadataMu.Lock() + } qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) if err != nil && err != syserror.ENOENT { + if child != nil { + child.metadataMu.Unlock() + } return nil, err } if child != nil { if !file.isNil() && inoFromPath(qid.Path) == child.ino { // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) - child.updateFromP9Attrs(attrMask, &attr) + child.updateFromP9AttrsLocked(attrMask, &attr) + child.metadataMu.Unlock() return child, nil } + child.metadataMu.Unlock() if file.isNil() && child.isSynthetic() { // We have a synthetic file, and no remote file has arisen to // replace it. diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index b74d489a0..ba7f650b3 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -785,8 +785,8 @@ func (d *dentry) cachedMetadataAuthoritative() bool { // updateFromP9Attrs is called to update d's metadata after an update from the // remote filesystem. -func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { - d.metadataMu.Lock() +// Precondition: d.metadataMu must be locked. +func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { if mask.Mode { if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want { d.metadataMu.Unlock() @@ -822,7 +822,6 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { if mask.Size { d.updateFileSizeLocked(attr.Size) } - d.metadataMu.Unlock() } // Preconditions: !d.isSynthetic() @@ -834,6 +833,10 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { file p9file handleMuRLocked bool ) + // d.metadataMu must be locked *before* we getAttr so that we do not end up + // updating stale attributes in d.updateFromP9AttrsLocked(). + d.metadataMu.Lock() + defer d.metadataMu.Unlock() d.handleMu.RLock() if !d.handle.file.isNil() { file = d.handle.file @@ -849,7 +852,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { if err != nil { return err } - d.updateFromP9Attrs(attrMask, &attr) + d.updateFromP9AttrsLocked(attrMask, &attr) return nil } diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index c40c6d673..c0fd3425b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,5 +20,6 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/usermem", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index ff81ea6e6..e76e498de 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -40,6 +40,8 @@ go_library( "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index a92aed2c9..ec5506efc 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -36,6 +36,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( @@ -319,7 +321,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -364,7 +366,8 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr if err != nil { return nil, syserr.FromError(err) } - return opt, nil + optP := primitive.ByteSlice(opt) + return &optP, nil } // SetSockOpt implements socket.Socket.SetSockOpt. diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 1243143ea..d9394055d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -145,7 +145,7 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin // Each rule corresponds to an entry. entry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ + Entry: linux.IPTEntry{ IP: linux.IPTIP{ Protocol: uint16(rule.Filter.Protocol), }, @@ -153,20 +153,20 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin TargetOffset: linux.SizeOfIPTEntry, }, } - copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) - copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) - copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) - copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) - copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) - copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IP.Src[:], rule.Filter.Src) + copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP } if rule.Filter.SrcInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP } if rule.Filter.OutputInterfaceInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } for _, matcher := range rule.Matchers { @@ -178,8 +178,8 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - entry.TargetOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) } // Serialize and append the target. @@ -188,11 +188,11 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) nflog("convert to binary: adding entry: %+v", entry) - entries.Size += uint32(entry.NextOffset) + entries.Size += uint32(entry.Entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) info.NumEntries++ } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index d5ca3ac56..0546801bf 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -36,6 +36,8 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 81f34c5a2..98ca7add0 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -38,6 +38,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const sizeOfInt32 int = 4 @@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - return int32(s.sendBufferSize), nil + sendBufferSizeP := primitive.Int32(s.sendBufferSize) + return &sendBufferSizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - return int32(math.MaxInt32), nil + recvBufferSizeP := primitive.Int32(math.MaxInt32) + return &recvBufferSizeP, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var passcred int32 + var passcred primitive.Int32 if s.Passcred() { passcred = 1 } - return passcred, nil + return &passcred, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index ea6ebd0e2..1fb777a6c 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -51,6 +51,8 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 964ec8414..9856ab8c5 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -62,6 +62,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { @@ -910,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -920,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -956,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -971,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -981,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -1014,7 +1016,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1025,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // Get the last error and convert it. err := ep.GetSockOpt(tcpip.ErrorOption{}) if err == nil { - return int32(0), nil + optP := primitive.Int32(0) + return &optP, nil } - return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil + + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + return &optP, nil case linux.SO_PEERCRED: if family != linux.AF_UNIX || outLen < syscall.SizeofUcred { @@ -1035,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } tcred := t.Credentials() - return syscall.Ucred{ - Pid: int32(t.ThreadGroup().ID()), - Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - }, nil + creds := linux.ControlMessageCredentials{ + PID: int32(t.ThreadGroup().ID()), + UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), + GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), + } + return &creds, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -1050,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1066,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { @@ -1082,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_REUSEADDR: if outLen < sizeOfInt32 { @@ -1093,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { @@ -1104,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1112,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } if v == 0 { - return []byte{}, nil + var b primitive.ByteSlice + return &b, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument @@ -1127,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // interface was removed. return nil, syserr.ErrUnknownDevice } - return append([]byte(nic.Name), 0), nil + + name := primitive.ByteSlice(append([]byte(nic.Name), 0)) + return &name, nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1138,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { @@ -1149,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - return linux.Linger{}, nil + + linger := linux.Linger{} + return &linger, nil case linux.SO_SNDTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1163,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.SendTimeout()), nil + sendTimeout := linux.NsecToTimeval(s.SendTimeout()) + return &sendTimeout, nil case linux.SO_RCVTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1171,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.RecvTimeout()), nil + recvTimeout := linux.NsecToTimeval(s.RecvTimeout()) + return &recvTimeout, nil case linux.SO_OOBINLINE: if outLen < sizeOfInt32 { @@ -1183,7 +1207,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.SO_NO_CHECK: if outLen < sizeOfInt32 { @@ -1194,7 +1219,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1203,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { @@ -1214,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(!v), nil + + vP := primitive.Int32(boolToInt32(!v)) + return &vP, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { @@ -1225,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { @@ -1236,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1247,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { @@ -1259,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveIdle, nil case linux.TCP_KEEPINTVL: if outLen < sizeOfInt32 { @@ -1271,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveInterval, nil case linux.TCP_KEEPCNT: if outLen < sizeOfInt32 { @@ -1283,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_USER_TIMEOUT: if outLen < sizeOfInt32 { @@ -1295,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Millisecond), nil + tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond) + return &tcpUserTimeout, nil case linux.TCP_INFO: var v tcpip.TCPInfoOption @@ -1309,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa info := linux.TCPInfo{} // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &info) - if len(ib) > outLen { - ib = ib[:outLen] + buf := t.CopyScratchBuffer(info.SizeBytes()) + info.MarshalUnsafe(buf) + if len(buf) > outLen { + buf = buf[:outLen] } - - return ib, nil + bufP := primitive.ByteSlice(buf) + return &bufP, nil case linux.TCP_CC_INFO, linux.TCP_NOTSENT_LOWAT, @@ -1344,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } b := make([]byte, toCopy) copy(b, v) - return b, nil + + bP := primitive.ByteSlice(b) + return &bP, nil case linux.TCP_LINGER2: if outLen < sizeOfInt32 { @@ -1356,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + return &lingerTimeout, nil case linux.TCP_DEFER_ACCEPT: if outLen < sizeOfInt32 { @@ -1368,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second) + return &tcpDeferAccept, nil case linux.TCP_SYNCNT: if outLen < sizeOfInt32 { @@ -1379,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_WINDOW_CLAMP: if outLen < sizeOfInt32 { @@ -1391,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil default: emitUnimplementedEventTCP(t, name) } @@ -1400,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { @@ -1411,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1419,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_TCLASS: // Length handling for parity with Linux. if outLen == 0 { - return make([]byte, 0), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - uintv := uint32(v) + uintv := primitive.Uint32(v) // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + ib := t.CopyScratchBuffer(uintv.SizeBytes()) + uintv.MarshalUnsafe(ib) // Handle cases where outLen is lesser than sizeOfInt32. if len(ib) > outLen { ib = ib[:outLen] } - return ib, nil + ibP := primitive.ByteSlice(ib) + return &ibP, nil case linux.IPV6_RECVTCLASS: if outLen < sizeOfInt32 { @@ -1444,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIPv6(t, name) @@ -1453,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1466,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } // Fill in the default value, if needed. - if v == 0 { - v = DefaultTTL + vP := primitive.Int32(v) + if vP == 0 { + vP = DefaultTTL } - return int32(v), nil + return &vP, nil case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { @@ -1482,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_MULTICAST_IF: if outLen < len(linux.InetAddr{}) { @@ -1496,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(*linux.SockAddrInet).Addr, nil + return &a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1507,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { - return []byte(nil), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { - return uint8(v), nil + vP := primitive.Uint8(v) + return &vP, nil } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_RECVTOS: if outLen < sizeOfInt32 { @@ -1532,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1543,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIP(t, name) diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index d65a89316..a9025b0ec 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -31,6 +31,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // SocketVFS2 encapsulates all the state needed to represent a network stack @@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketVFS2 rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return entries, nil + return &entries, nil } } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fcd7f9d7f..d112757fb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // ControlMessages represents the union of unix control messages and tcpip @@ -86,7 +87,7 @@ type SocketOps interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux syscall. - GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux syscall. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cca5e70f1..061a689a9 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -35,5 +35,6 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4bb2b6ff4..0482d33cf 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -40,6 +40,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index ff2149250..05c16fcfe 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketVFS2 implements socket.SocketVFS2 (and by extension, @@ -89,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 217fcfef2..4a9b04fd0 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -99,5 +99,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 0760af77b..414fce8e3 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -29,6 +29,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // LINT.IfChange @@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 0c740335b..64696b438 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -72,5 +72,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 10b668477..8096a8f9c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -30,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // minListenBacklog is the minimum reasonable backlog for listening sockets. @@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 945a364a7..63ab11f8c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -15,12 +15,15 @@ package vfs2 import ( + "io" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -110,16 +113,20 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Move data. var ( - n int64 - err error - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { // If both input and output are pipes, delegate to the pipe - // implementation. Otherwise, exactly one end is a pipe, which we - // ensure is consistently ordered after the non-pipe FD's locks by - // passing the pipe FD as usermem.IO to the non-pipe end. + // implementation. Otherwise, exactly one end is a pipe, which + // we ensure is consistently ordered after the non-pipe FD's + // locks by passing the pipe FD as usermem.IO to the non-pipe + // end. switch { case inIsPipe && outIsPipe: n, err = pipe.Splice(t, outPipeFD, inPipeFD, count) @@ -137,38 +144,15 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } else { n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } + default: + panic("not possible") } + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { break } - - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the splice operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. - } - if err = t.Block(inCh); err != nil { - break - } - } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. - } - if err = t.Block(outCh); err != nil { - break - } + if err = dw.waitForBoth(t); err != nil { + break } } @@ -247,45 +231,256 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo // Copy data. var ( - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { - n, err := pipe.Tee(t, outPipeFD, inPipeFD, count) - if n != 0 { - return uintptr(n), nil, nil + n, err = pipe.Tee(t, outPipeFD, inPipeFD, count) + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + break + } + if err = dw.waitForBoth(t); err != nil { + break + } + } + if n == 0 { + return 0, nil, err + } + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// Sendfile implements linux system call sendfile(2). +func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + outFD := args[0].Int() + inFD := args[1].Int() + offsetAddr := args[2].Pointer() + count := int64(args[3].SizeT()) + + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + if !inFile.IsReadable() { + return 0, nil, syserror.EBADF + } + + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + if !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // Verify that the outFile Append flag is not set. + if outFile.StatusFlags()&linux.O_APPEND != 0 { + return 0, nil, syserror.EINVAL + } + + // Verify that inFile is a regular file or block device. This is a + // requirement; the same check appears in Linux + // (fs/splice.c:splice_direct_to_actor). + if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil { + return 0, nil, err + } else if stat.Mask&linux.STATX_TYPE == 0 || + (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) { + return 0, nil, syserror.EINVAL + } + + // Copy offset if it exists. + offset := int64(-1) + if offsetAddr != 0 { + if inFile.Options().DenyPRead { + return 0, nil, syserror.ESPIPE } - if err != syserror.ErrWouldBlock || nonBlock { + if _, err := t.CopyIn(offsetAddr, &offset); err != nil { return 0, nil, err } + if offset < 0 { + return 0, nil, syserror.EINVAL + } + if offset+count < 0 { + return 0, nil, syserror.EINVAL + } + } + + // Validate count. This must come after offset checks. + if count < 0 { + return 0, nil, syserror.EINVAL + } + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the tee operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. + // Copy data. + var ( + n int64 + err error + ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + // Reading from input file should never block, since it is regular or + // block device. We only need to check if writing to the output file + // can block. + nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0 + if outIsPipe { + for n < count { + var spliceN int64 + if offset != -1 { + spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{}) + offset += spliceN + } else { + spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } - if err := t.Block(inCh); err != nil { - return 0, nil, err + n += spliceN + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) + } + if err != nil { + break } } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. + } else { + // Read inFile to buffer, then write the contents to outFile. + buf := make([]byte, count) + for n < count { + var readN int64 + if offset != -1 { + readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{}) + offset += readN + } else { + readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + } + if readN == 0 && err == io.EOF { + // We reached the end of the file. Eat the + // error and exit the loop. + err = nil + break } - if err := t.Block(outCh); err != nil { - return 0, nil, err + n += readN + if err != nil { + break + } + + // Write all of the bytes that we read. This may need + // multiple write calls to complete. + wbuf := buf[:n] + for len(wbuf) > 0 { + var writeN int64 + writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{}) + wbuf = wbuf[writeN:] + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForOut(t) + } + if err != nil { + // We didn't complete the write. Only + // report the bytes that were actually + // written, and rewind the offset. + notWritten := int64(len(wbuf)) + n -= notWritten + if offset != -1 { + offset -= notWritten + } + break + } + } + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) } + if err != nil { + break + } + } + } + + if offsetAddr != 0 { + // Copy out the new offset. + if _, err := t.CopyOut(offsetAddr, offset); err != nil { + return 0, nil, err + } + } + + if n == 0 { + return 0, nil, err + } + + inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not +// thread-safe, and does not take a reference on the vfs.FileDescriptions. +// +// Users must call destroy() when finished. +type dualWaiter struct { + inFile *vfs.FileDescription + outFile *vfs.FileDescription + + inW waiter.Entry + inCh chan struct{} + outW waiter.Entry + outCh chan struct{} +} + +// waitForBoth waits for both dw.inFile and dw.outFile to be ready. +func (dw *dualWaiter) waitForBoth(t *kernel.Task) error { + if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if dw.inCh == nil { + dw.inW, dw.inCh = waiter.NewChannelEntry(nil) + dw.inFile.EventRegister(&dw.inW, eventMaskRead) + // We might be ready now. Try again before blocking. + return nil + } + if err := t.Block(dw.inCh); err != nil { + return err + } + } + return dw.waitForOut(t) +} + +// waitForOut waits for dw.outfile to be read. +func (dw *dualWaiter) waitForOut(t *kernel.Task) error { + if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if dw.outCh == nil { + dw.outW, dw.outCh = waiter.NewChannelEntry(nil) + dw.outFile.EventRegister(&dw.outW, eventMaskWrite) + // We might be ready now. Try again before blocking. + return nil } + if err := t.Block(dw.outCh); err != nil { + return err + } + } + return nil +} + +// destroy cleans up resources help by dw. No more calls to wait* can occur +// after destroy is called. +func (dw *dualWaiter) destroy() { + if dw.inCh != nil { + dw.inFile.EventUnregister(&dw.inW) + dw.inCh = nil + } + if dw.outCh != nil { + dw.outFile.EventUnregister(&dw.outW) + dw.outCh = nil } + dw.inFile = nil + dw.outFile = nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index 8f497ecc7..1b2cfad7d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -44,7 +44,7 @@ func Override() { s.Table[23] = syscalls.Supported("select", Select) s.Table[32] = syscalls.Supported("dup", Dup) s.Table[33] = syscalls.Supported("dup2", Dup2) - delete(s.Table, 40) // sendfile + s.Table[40] = syscalls.Supported("sendfile", Sendfile) s.Table[41] = syscalls.Supported("socket", Socket) s.Table[42] = syscalls.Supported("connect", Connect) s.Table[43] = syscalls.Supported("accept", Accept) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index d39baf620..559a1c4dd 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -49,7 +49,8 @@ const ( type manipType int const ( - manipDstPrerouting manipType = iota + manipNone manipType = iota + manipDstPrerouting manipDstOutput ) @@ -113,13 +114,11 @@ type conn struct { // update the state of tcb. It is immutable. tcbHook Hook - // mu protects tcb. + // mu protects all mutable state. mu sync.Mutex `state:"nosave"` - // tcb is TCB control block. It is used to keep track of states // of tcp connection and is protected by mu. tcb tcpconntrack.TCB - // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. It is protected by mu. lastUsed time.Time `state:".(unixTime)"` @@ -141,8 +140,26 @@ func (cn *conn) timedOut(now time.Time) bool { return now.Sub(cn.lastUsed) > defaultTimeout } +// update the connection tracking state. +// +// Precondition: ct.mu must be held. +func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { + // Update the state of tcb. tcb assumes it's always initialized on the + // client. However, we only need to know whether the connection is + // established or not, so the client/server distinction isn't important. + // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle + // other tcp states. + if ct.tcb.IsEmpty() { + ct.tcb.Init(tcpHeader) + } else if hook == ct.tcbHook { + ct.tcb.UpdateStateOutbound(tcpHeader) + } else { + ct.tcb.UpdateStateInbound(tcpHeader) + } +} + // ConnTrack tracks all connections created for NAT rules. Most users are -// expected to only call handlePacket and createConnFor. +// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. // // ConnTrack keeps all connections in a slice of buckets, each of which holds a // linked list of tuples. This gives us some desirable properties: @@ -248,8 +265,7 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { return nil, dirOriginal } -// createConnFor creates a new conn for pkt. -func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -272,10 +288,15 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg manip = manipDstOutput } conn := newConn(tid, replyTID, manip, hook) + ct.insertConn(conn) + return conn +} +// insertConn inserts conn into the appropriate table bucket. +func (ct *ConnTrack) insertConn(conn *conn) { // Lock the buckets in the correct order. - tupleBucket := ct.bucket(tid) - replyBucket := ct.bucket(replyTID) + tupleBucket := ct.bucket(conn.original.tupleID) + replyBucket := ct.bucket(conn.reply.tupleID) ct.mu.RLock() defer ct.mu.RUnlock() if tupleBucket < replyBucket { @@ -289,22 +310,37 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg ct.buckets[tupleBucket].mu.Lock() } - // Add the tuple to the map. - ct.buckets[tupleBucket].tuples.PushFront(&conn.original) - ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + // Now that we hold the locks, ensure the tuple hasn't been inserted by + // another thread. + alreadyInserted := false + for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { + if other.tupleID == conn.original.tupleID { + alreadyInserted = true + break + } + } + + if !alreadyInserted { + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + } // Unlocking can happen in any order. ct.buckets[tupleBucket].mu.Unlock() if tupleBucket != replyBucket { ct.buckets[replyBucket].mu.Unlock() } - - return conn } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. // TODO(gvisor.dev/issue/170): Change address for Prerouting hook. func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -322,12 +358,22 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { netHeader.SetSourceAddress(conn.original.dstAddr) } + // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated + // on inbound packets, so we don't recalculate them. However, we should + // support cases when they are validated, e.g. when we can't offload + // receive checksumming. + netHeader.SetChecksum(0) netHeader.SetChecksum(^netHeader.CalculateChecksum()) } // handlePacketOutput manipulates ports for packets in Output hook. func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -362,20 +408,31 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d } // handlePacket will manipulate the port and address of the packet if the -// connection exists. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +// connection exists. Returns whether, after the packet traverses the tables, +// it should create a new entry in the table. +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { if pkt.NatDone { - return + return false } if hook != Prerouting && hook != Output { - return + return false + } + + // TODO(gvisor.dev/issue/170): Support other transport protocols. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return false } conn, dir := ct.connFor(pkt) + // Connection or Rule not found for the packet. if conn == nil { - // Connection not found for the packet or the packet is invalid. - return + return true + } + + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return false } switch hook { @@ -395,14 +452,39 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - if tcpHeader := header.TCP(pkt.TransportHeader); conn.tcb.IsEmpty() { - conn.tcb.Init(tcpHeader) - conn.tcbHook = hook - } else if hook == conn.tcbHook { - conn.tcb.UpdateStateOutbound(tcpHeader) - } else { - conn.tcb.UpdateStateInbound(tcpHeader) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + + return false +} + +// maybeInsertNoop tries to insert a no-op connection entry to keep connections +// from getting clobbered when replies arrive. It only inserts if there isn't +// already a connection for pkt. +// +// This should be called after traversing iptables rules only, to ensure that +// pkt.NatDone is set correctly. +func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { + // If there were a rule applying to this packet, it would be marked + // with NatDone. + if pkt.NatDone { + return + } + + // We only track TCP connections. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return + } + + // This is the first packet we're seeing for the TCP connection. Insert + // the noop entry (an identity mapping) so that the response doesn't + // get NATed, breaking the connection. + tid, err := packetToTupleID(pkt) + if err != nil { + return } + conn := newConn(tid, tid.reply(), manipNone, hook) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + ct.insertConn(conn) } // bucket gets the conntrack bucket for a tupleID. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 5f647c5fe..ca1dda695 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -245,13 +245,18 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.handlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. it.mu.RLock() defer it.mu.RUnlock() priorities := it.priorities[hook] for _, tableID := range priorities { + // If handlePacket already NATed the packet, we don't need to + // check the NAT table. + if tableID == natID && pkt.NatDone { + continue + } table := it.tables[tableID] ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { @@ -281,6 +286,20 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr } } + // If this connection should be tracked, try to add an entry for it. If + // traversing the nat table didn't end in adding an entry, + // maybeInsertNoop will add a no-op entry for the connection. This is + // needeed when establishing connections so that the SYN/ACK reply to an + // outgoing SYN is delivered to the correct endpoint rather than being + // redirected by a prerouting rule. + // + // From the iptables documentation: "If there is no rule, a `null' + // binding is created: this usually does not map the packet, but exists + // to ensure we don't map another stream over an existing one." + if shouldTrack { + it.connections.maybeInsertNoop(pkt, hook) + } + // Every table returned Accept. return true } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index d43f60c67..dc88033c7 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -153,7 +153,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. - if conn := ct.createConnFor(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index c2e9fd29f..aefe0e2b2 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -456,7 +456,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { defer e.mu.Unlock() // If a local address was specified, verify that it's valid. - if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index 1036b0630..05e3637f7 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -31,5 +31,6 @@ go_test( deps = [ "//pkg/log", "//pkg/p9", + "//pkg/test/testutil", ], ) diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index b7521bda7..82a46910e 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -132,7 +132,7 @@ func (a *attachPoint) Attach() (p9.File, error) { return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix) } - f, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) { + f, readable, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) { return fd.Open(a.prefix, openFlags|mode, 0) }) if err != nil { @@ -144,7 +144,7 @@ func (a *attachPoint) Attach() (p9.File, error) { return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err) } - lf, err := newLocalFile(a, f, a.prefix, stat) + lf, err := newLocalFile(a, f, a.prefix, readable, stat) if err != nil { return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err) } @@ -212,6 +212,10 @@ type localFile struct { // opened with. file *fd.FD + // controlReadable tells whether 'file' was opened with read permissions + // during a walk. + controlReadable bool + // mode is the mode in which the file was opened. Set to invalidMode // if localFile isn't opened. mode p9.OpenFlags @@ -251,49 +255,57 @@ func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) { return fd.New(d), nil } -func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, error) { +func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, bool, error) { path := path.Join(parent.hostPath, name) - f, err := openAnyFile(path, func(mode int) (*fd.FD, error) { + f, readable, err := openAnyFile(path, func(mode int) (*fd.FD, error) { return fd.OpenAt(parent.file, name, openFlags|mode, 0) }) - return f, path, err + return f, path, readable, err } // openAnyFile attempts to open the file in O_RDONLY and if it fails fallsback // to O_PATH. 'path' is used for logging messages only. 'fn' is what does the // actual file open and is customizable by the caller. -func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error) { +func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, bool, error) { // Attempt to open file in the following mode in order: // 1. RDONLY | NONBLOCK: for all files, directories, ro mounts, FIFOs. // Use non-blocking to prevent getting stuck inside open(2) for // FIFOs. This option has no effect on regular files. // 2. PATH: for symlinks, sockets. - modes := []int{syscall.O_RDONLY | syscall.O_NONBLOCK, unix.O_PATH} + options := []struct { + mode int + readable bool + }{ + { + mode: syscall.O_RDONLY | syscall.O_NONBLOCK, + readable: true, + }, + { + mode: unix.O_PATH, + readable: false, + }, + } var err error - var file *fd.FD - for i, mode := range modes { - file, err = fn(mode) + for i, option := range options { + var file *fd.FD + file, err = fn(option.mode) if err == nil { - // openat succeeded, we're done. - break + // Succeeded opening the file, we're done. + return file, option.readable, nil } switch e := extractErrno(err); e { case syscall.ENOENT: // File doesn't exist, no point in retrying. - return nil, e + return nil, false, e } - // openat failed. Try again with next mode, preserving 'err' in case this - // was the last attempt. - log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|mode, path, err) + // File failed to open. Try again with next mode, preserving 'err' in case + // this was the last attempt. + log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|option.mode, path, err) } - if err != nil { - // All attempts to open file have failed, return the last error. - log.Debugf("Failed to open file, path: %q, err: %v", path, err) - return nil, extractErrno(err) - } - - return file, nil + // All attempts to open file have failed, return the last error. + log.Debugf("Failed to open file, path: %q, err: %v", path, err) + return nil, false, extractErrno(err) } func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) { @@ -316,18 +328,19 @@ func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, err return ft, nil } -func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) (*localFile, error) { +func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat syscall.Stat_t) (*localFile, error) { ft, err := getSupportedFileType(stat, a.conf.HostUDS) if err != nil { return nil, err } return &localFile{ - attachPoint: a, - hostPath: path, - file: file, - mode: invalidMode, - ft: ft, + attachPoint: a, + hostPath: path, + file: file, + mode: invalidMode, + ft: ft, + controlReadable: readable, }, nil } @@ -380,7 +393,7 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { // Check if control file can be used or if a new open must be created. var newFile *fd.FD - if flags == p9.ReadOnly { + if flags == p9.ReadOnly && l.controlReadable { log.Debugf("Open reusing control file, flags: %v, %q", flags, l.hostPath) newFile = l.file } else { @@ -518,7 +531,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) func (l *localFile) Walk(names []string) ([]p9.QID, p9.File, error) { // Duplicate current file if 'names' is empty. if len(names) == 0 { - newFile, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) { + newFile, readable, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) { return reopenProcFd(l.file, openFlags|mode) }) if err != nil { @@ -532,10 +545,11 @@ func (l *localFile) Walk(names []string) ([]p9.QID, p9.File, error) { } c := &localFile{ - attachPoint: l.attachPoint, - hostPath: l.hostPath, - file: newFile, - mode: invalidMode, + attachPoint: l.attachPoint, + hostPath: l.hostPath, + file: newFile, + mode: invalidMode, + controlReadable: readable, } return []p9.QID{l.attachPoint.makeQID(stat)}, c, nil } @@ -543,7 +557,7 @@ func (l *localFile) Walk(names []string) ([]p9.QID, p9.File, error) { var qids []p9.QID last := l for _, name := range names { - f, path, err := openAnyFileFromParent(last, name) + f, path, readable, err := openAnyFileFromParent(last, name) if last != l { last.Close() } @@ -555,7 +569,7 @@ func (l *localFile) Walk(names []string) ([]p9.QID, p9.File, error) { f.Close() return nil, nil, extractErrno(err) } - c, err := newLocalFile(last.attachPoint, f, path, stat) + c, err := newLocalFile(last.attachPoint, f, path, readable, stat) if err != nil { f.Close() return nil, nil, extractErrno(err) diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index 05af7e397..5b37e6aa1 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -26,6 +26,19 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} + +var ( + allTypes = []fileType{regular, directory, symlink} + + // allConfs is set in init(). + allConfs []Config + + rwConfs = []Config{{ROMount: false}} + roConfs = []Config{{ROMount: true}} ) func init() { @@ -39,6 +52,13 @@ func init() { } } +func configTestName(config *Config) string { + if config.ROMount { + return "ROMount" + } + return "RWMount" +} + func assertPanic(t *testing.T, f func()) { defer func() { if r := recover(); r == nil { @@ -88,18 +108,6 @@ func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error { return nil } -var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} - -var ( - allTypes = []fileType{regular, directory, symlink} - - // allConfs is set in init() above. - allConfs []Config - - rwConfs = []Config{{ROMount: false}} - roConfs = []Config{{ROMount: true}} -) - type state struct { root *localFile file *localFile @@ -117,42 +125,46 @@ func runAll(t *testing.T, test func(*testing.T, state)) { func runCustom(t *testing.T, types []fileType, confs []Config, test func(*testing.T, state)) { for _, c := range confs { - t.Logf("Config: %+v", c) - for _, ft := range types { - t.Logf("File type: %v", ft) + name := fmt.Sprintf("%s/%v", configTestName(&c), ft) + t.Run(name, func(t *testing.T) { + path, name, err := setup(ft) + if err != nil { + t.Fatalf("%v", err) + } + defer os.RemoveAll(path) - path, name, err := setup(ft) - if err != nil { - t.Fatalf("%v", err) - } - defer os.RemoveAll(path) + a, err := NewAttachPoint(path, c) + if err != nil { + t.Fatalf("NewAttachPoint failed: %v", err) + } + root, err := a.Attach() + if err != nil { + t.Fatalf("Attach failed, err: %v", err) + } - a, err := NewAttachPoint(path, c) - if err != nil { - t.Fatalf("NewAttachPoint failed: %v", err) - } - root, err := a.Attach() - if err != nil { - t.Fatalf("Attach failed, err: %v", err) - } + _, file, err := root.Walk([]string{name}) + if err != nil { + root.Close() + t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err) + } - _, file, err := root.Walk([]string{name}) - if err != nil { + st := state{ + root: root.(*localFile), + file: file.(*localFile), + conf: c, + ft: ft, + } + test(t, st) + file.Close() root.Close() - t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err) - } - - st := state{root: root.(*localFile), file: file.(*localFile), conf: c, ft: ft} - test(t, st) - file.Close() - root.Close() + }) } } } func setup(ft fileType) (string, string, error) { - path, err := ioutil.TempDir("", "root-") + path, err := ioutil.TempDir(testutil.TmpDir(), "root-") if err != nil { return "", "", fmt.Errorf("ioutil.TempDir() failed, err: %v", err) } @@ -308,6 +320,32 @@ func TestUnopened(t *testing.T) { }) } +// TestOpenOPath is a regression test to ensure that a file that cannot be open +// for read is allowed to be open. This was happening because the control file +// was open with O_PATH, but Open() was not checking for it and allowing the +// control file to be reused. +func TestOpenOPath(t *testing.T) { + runCustom(t, []fileType{regular}, rwConfs, func(t *testing.T, s state) { + // Fist remove all permissions on the file. + if err := s.file.SetAttr(p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(0)}); err != nil { + t.Fatalf("SetAttr(): %v", err) + } + // Then walk to the file again to open a new control file. + filename := filepath.Base(s.file.hostPath) + _, newFile, err := s.root.Walk([]string{filename}) + if err != nil { + t.Fatalf("root.Walk(%q): %v", filename, err) + } + + if newFile.(*localFile).controlReadable { + t.Fatalf("control file didn't open with O_PATH: %+v", newFile) + } + if _, _, _, err := newFile.Open(p9.ReadOnly); err != syscall.EACCES { + t.Fatalf("Open() should have failed, got: %v, wanted: EACCES", err) + } + }) +} + func SetGetAttr(l *localFile, valid p9.SetAttrMask, attr p9.SetAttr) (p9.Attr, error) { if err := l.SetAttr(valid, attr); err != nil { return p9.Attr{}, err diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index f5ac79370..f303030aa 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -263,6 +263,13 @@ func TestNATPreRedirectTCPPort(t *testing.T) { singleTest(t, NATPreRedirectTCPPort{}) } +func TestNATPreRedirectTCPOutgoing(t *testing.T) { + singleTest(t, NATPreRedirectTCPOutgoing{}) +} + +func TestNATOutRedirectTCPIncoming(t *testing.T) { + singleTest(t, NATOutRedirectTCPIncoming{}) +} func TestNATOutRedirectUDPPort(t *testing.T) { singleTest(t, NATOutRedirectUDPPort{}) } diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 8562b0820..149dec2bb 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -28,6 +28,8 @@ const ( func init() { RegisterTestCase(NATPreRedirectUDPPort{}) RegisterTestCase(NATPreRedirectTCPPort{}) + RegisterTestCase(NATPreRedirectTCPOutgoing{}) + RegisterTestCase(NATOutRedirectTCPIncoming{}) RegisterTestCase(NATOutRedirectUDPPort{}) RegisterTestCase(NATOutRedirectTCPPort{}) RegisterTestCase(NATDropUDP{}) @@ -91,6 +93,56 @@ func (NATPreRedirectTCPPort) LocalAction(ip net.IP) error { return connectTCP(ip, dropPort, sendloopDuration) } +// NATPreRedirectTCPOutgoing verifies that outgoing TCP connections aren't +// affected by PREROUTING connection tracking. +type NATPreRedirectTCPOutgoing struct{} + +// Name implements TestCase.Name. +func (NATPreRedirectTCPOutgoing) Name() string { + return "NATPreRedirectTCPOutgoing" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectTCPOutgoing) ContainerAction(ip net.IP) error { + // Redirect all incoming TCP traffic to a closed port. + if err := natTable("-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return connectTCP(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectTCPOutgoing) LocalAction(ip net.IP) error { + return listenTCP(acceptPort, sendloopDuration) +} + +// NATOutRedirectTCPIncoming verifies that incoming TCP connections aren't +// affected by OUTPUT connection tracking. +type NATOutRedirectTCPIncoming struct{} + +// Name implements TestCase.Name. +func (NATOutRedirectTCPIncoming) Name() string { + return "NATOutRedirectTCPIncoming" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectTCPIncoming) ContainerAction(ip net.IP) error { + // Redirect all outgoing TCP traffic to a closed port. + if err := natTable("-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return listenTCP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectTCPIncoming) LocalAction(ip net.IP) error { + return connectTCP(ip, acceptPort, sendloopDuration) +} + // NATOutRedirectUDPPort tests that packets are redirected to different port. type NATOutRedirectUDPPort struct{} diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index c06a75ada..15e5fd223 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -640,11 +640,13 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:sendfile_socket_test", + vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sendfile_test", + vfs2 = "True", ) syscall_test( diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index ce54dc064..8d6e5c913 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -262,6 +262,27 @@ TEST_P(RawSocketTest, SendWithoutConnectFails) { SyscallFailsWithErrno(EDESTADDRREQ)); } +// Wildcard Bind. +TEST_P(RawSocketTest, BindToWildcard) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + struct sockaddr_storage addr; + addr = {}; + + // We don't set ports because raw sockets don't have a notion of ports. + if (Family() == AF_INET) { + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_family = AF_INET6; + sin6->sin6_addr = in6addr_any; + } + + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); +} + // Bind to localhost. TEST_P(RawSocketTest, BindToLocalhost) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); diff --git a/tools/bazel.mk b/tools/bazel.mk index b6f23fe7e..e27e907ab 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -30,9 +30,6 @@ DOCKER_SOCKET := /var/run/docker.sock # Bazel flags. OPTIONS += --test_output=errors --keep_going --verbose_failures=true -ifneq ($(AUTH_CREDENTIALS),) -OPTIONS += --auth_credentials=${AUTH_CREDENTIALS} --config=remote -endif BAZEL := bazel $(STARTUP_OPTIONS) # Non-configurable. @@ -71,7 +68,7 @@ SHELL=/bin/bash -o pipefail bazel-server-start: load-default ## Starts the bazel server. @mkdir -p $(BAZEL_CACHE) @mkdir -p $(GCLOUD_CONFIG) - @if docker ps --all | grep $(DOCKER_NAME); then docker rm $(DOCKER_NAME); fi + @if docker ps --all | grep $(DOCKER_NAME); then docker rm -f $(DOCKER_NAME); fi docker run -d --rm \ --init \ --name $(DOCKER_NAME) \ @@ -102,7 +99,7 @@ bazel-server: ## Ensures that the server exists. Used as an internal target. @docker exec $(DOCKER_NAME) true || $(MAKE) bazel-server-start .PHONY: bazel-server -build_cmd = docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) $(STARTUP_OPTIONS) build $(OPTIONS) $(TARGETS)' +build_cmd = docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) build $(OPTIONS) $(TARGETS)' build_paths = $(build_cmd) 2>&1 \ | tee /proc/self/fd/2 \ diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index 4886efddf..68d759083 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -9,11 +9,9 @@ automatically generating code to marshal go data structures to memory. `binary.Marshal` by moving the go runtime reflection necessary to marshal a struct to compile-time. -`go_marshal` automatically generates implementations for `abi.Marshallable` and -`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall -implementations) can directly invoke `safemem.Reader.ReadToBlocks` and -`safemem.Writer.WriteFromBlocks`. Data structures that require custom -serialization will have manual implementations for these interfaces. +`go_marshal` automatically generates implementations for `marshal.Marshallable` +and `safemem.{Reader,Writer}`. Data structures that require custom serialization +will have manual implementations for these interfaces. Data structures can be flagged for code generation by adding a struct-level comment `// +marshal`. diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go index ebcf130ae..d93edda8b 100644 --- a/tools/go_marshal/primitive/primitive.go +++ b/tools/go_marshal/primitive/primitive.go @@ -17,10 +17,22 @@ package primitive import ( + "io" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/marshal" ) +// Int8 is a marshal.Marshallable implementation for int8. +// +// +marshal slice:Int8Slice:inner +type Int8 int8 + +// Uint8 is a marshal.Marshallable implementation for uint8. +// +// +marshal slice:Uint8Slice:inner +type Uint8 uint8 + // Int16 is a marshal.Marshallable implementation for int16. // // +marshal slice:Int16Slice:inner @@ -51,6 +63,66 @@ type Int64 int64 // +marshal slice:Uint64Slice:inner type Uint64 uint64 +// ByteSlice is a marshal.Marshallable implementation for []byte. +// This is a convenience wrapper around a dynamically sized type, and can't be +// embedded in other marshallable types because it breaks assumptions made by +// go-marshal internals. It violates the "no dynamically-sized types" +// constraint of the go-marshal library. +type ByteSlice []byte + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (b *ByteSlice) SizeBytes() int { + return len(*b) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (b *ByteSlice) MarshalBytes(dst []byte) { + copy(dst, *b) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (b *ByteSlice) UnmarshalBytes(src []byte) { + copy(*b, src) +} + +// Packed implements marshal.Marshallable.Packed. +func (b *ByteSlice) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (b *ByteSlice) MarshalUnsafe(dst []byte) { + b.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (b *ByteSlice) UnmarshalUnsafe(src []byte) { + b.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (b *ByteSlice) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + return task.CopyInBytes(addr, *b) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (b *ByteSlice) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + return task.CopyOutBytes(addr, *b) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (b *ByteSlice) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + return task.CopyOutBytes(addr, (*b)[:limit]) +} + +// WriteTo implements io.WriterTo.WriteTo. +func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(*b) + return int64(n), err +} + +var _ marshal.Marshallable = (*ByteSlice)(nil) + // Below, we define some convenience functions for marshalling primitive types // using the newtypes above, without requiring superfluous casts. diff --git a/tools/nogo/build.go b/tools/nogo/build.go index fb9f17b62..433d13738 100644 --- a/tools/nogo/build.go +++ b/tools/nogo/build.go @@ -31,10 +31,10 @@ var ( ) // findStdPkg needs to find the bundled standard library packages. -func findStdPkg(path, GOOS, GOARCH string) (io.ReadCloser, error) { +func (i *importer) findStdPkg(path string) (io.ReadCloser, error) { if path == "C" { // Cgo builds cannot be analyzed. Skip. return nil, ErrSkip } - return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", GOOS, GOARCH, path)) + return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", i.GOOS, i.GOARCH, path)) } diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index 6560b57c8..d399079c5 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -28,8 +28,10 @@ def _nogo_aspect_impl(target, ctx): else: return [NogoInfo()] - # Construct the Go environment from the go_context.env dictionary. - env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_context(ctx).env.items()]) + go_ctx = go_context(ctx) + + # Construct the Go environment from the go_ctx.env dictionary. + env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_ctx.env.items()]) # Start with all target files and srcs as input. inputs = target.files.to_list() + srcs @@ -45,7 +47,7 @@ def _nogo_aspect_impl(target, ctx): "#!/bin/bash", "%s %s tool objdump %s > %s\n" % ( env_prefix, - go_context(ctx).go.path, + go_ctx.go.path, [f.path for f in binaries if f.path.endswith(".a")][0], disasm_file.path, ), @@ -53,7 +55,7 @@ def _nogo_aspect_impl(target, ctx): ctx.actions.run( inputs = binaries, outputs = [disasm_file], - tools = go_context(ctx).runfiles, + tools = go_ctx.runfiles, mnemonic = "GoObjdump", progress_message = "Objdump %s" % target.label, executable = dumper, @@ -70,9 +72,11 @@ def _nogo_aspect_impl(target, ctx): ImportPath = importpath, GoFiles = [src.path for src in srcs if src.path.endswith(".go")], NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")], - GOOS = go_context(ctx).goos, - GOARCH = go_context(ctx).goarch, - Tags = go_context(ctx).tags, + # Google's internal build system needs a bit more help to find std. + StdZip = go_ctx.std_zip.short_path if hasattr(go_ctx, "std_zip") else "", + GOOS = go_ctx.goos, + GOARCH = go_ctx.goarch, + Tags = go_ctx.tags, FactMap = {}, # Constructed below. ImportMap = {}, # Constructed below. FactOutput = facts.path, @@ -110,7 +114,7 @@ def _nogo_aspect_impl(target, ctx): ctx.actions.run( inputs = inputs, outputs = [facts], - tools = go_context(ctx).runfiles, + tools = go_ctx.runfiles, executable = ctx.files._nogo[0], mnemonic = "GoStaticAnalysis", progress_message = "Analyzing %s" % target.label, diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go index 5ee586c3e..ea1e97076 100644 --- a/tools/nogo/nogo.go +++ b/tools/nogo/nogo.go @@ -55,6 +55,7 @@ type pkgConfig struct { FactMap map[string]string FactOutput string Objdump string + StdZip string } // loadFacts finds and loads facts per FactMap. @@ -111,7 +112,7 @@ func (i *importer) Import(path string) (*types.Package, error) { if !ok { // Not found in the import path. Attempt to find the package // via the standard library. - rc, err = findStdPkg(path, i.GOOS, i.GOARCH) + rc, err = i.findStdPkg(path) } else { // Open the file. rc, err = os.Open(realPath) |