diff options
Diffstat (limited to 'pkg')
28 files changed, 762 insertions, 232 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 2b789c4ec..a4bb62013 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -72,6 +72,9 @@ go_library( "//pkg/abi", "//pkg/binary", "//pkg/bits", + "//pkg/usermem", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 46d8b0b42..a91f9f018 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -14,6 +14,14 @@ package linux +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" +) + // This file contains structures required to support netfilter, specifically // the iptables tool. @@ -76,6 +84,8 @@ const ( // IPTEntry is an iptable rule. It corresponds to struct ipt_entry in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTEntry struct { // IP is used to filter packets based on the IP header. IP IPTIP @@ -112,21 +122,41 @@ type IPTEntry struct { // SizeOfIPTEntry is the size of an IPTEntry. const SizeOfIPTEntry = 112 -// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This -// struct marshaled via the binary package to write an IPTEntry to userspace. +// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. +// KernelIPTEntry itself is not Marshallable but it implements some methods of +// marshal.Marshallable that help in other implementations of Marshallable. type KernelIPTEntry struct { - IPTEntry + Entry IPTEntry // Elems holds the data for all this rule's matches followed by the // target. It is variable length -- users have to iterate over any // matches and use TargetOffset and NextOffset to make sense of the // data. - Elems []byte + Elems primitive.ByteSlice +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTEntry) SizeBytes() int { + return ke.Entry.SizeBytes() + ke.Elems.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { + ke.Entry.MarshalBytes(dst) + ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { + ke.Entry.UnmarshalBytes(src) + ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } // IPTIP contains information for matching a packet's IP header. // It corresponds to struct ipt_ip in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTIP struct { // Src is the source IP address. Src InetAddr @@ -189,6 +219,8 @@ const SizeOfIPTIP = 84 // XTCounters holds packet and byte counts for a rule. It corresponds to struct // xt_counters in include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTCounters struct { // Pcnt is the packet count. Pcnt uint64 @@ -321,6 +353,8 @@ const SizeOfXTRedirectTarget = 56 // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetinfo struct { Name TableName ValidHooks uint32 @@ -336,6 +370,8 @@ const SizeOfIPTGetinfo = 84 // IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It // corresponds to struct ipt_get_entries in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetEntries struct { Name TableName Size uint32 @@ -350,13 +386,103 @@ type IPTGetEntries struct { const SizeOfIPTGetEntries = 40 // KernelIPTGetEntries is identical to IPTGetEntries, but includes the -// Entrytable field. This struct marshaled via the binary package to write an -// KernelIPTGetEntries to userspace. +// Entrytable field. This has been manually made marshal.Marshallable since it +// is dynamically sized. type KernelIPTGetEntries struct { IPTGetEntries Entrytable []KernelIPTEntry } +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTGetEntries) SizeBytes() int { + res := ke.IPTGetEntries.SizeBytes() + for _, entry := range ke.Entrytable { + res += entry.SizeBytes() + } + return res +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { + ke.IPTGetEntries.MarshalBytes(dst) + marshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) + marshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { + ke.IPTGetEntries.UnmarshalBytes(src) + unmarshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) + unmarshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// Packed implements marshal.Marshallable.Packed. +func (ke *KernelIPTGetEntries) Packed() bool { + // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an + // indirection to the actual data we want to marshal (the slice data + // pointer), and the memory for KernelIPTGetEntries contains the slice + // header which we don't want to marshal. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) { + // Fall back to safe Marshal because the type in not packed. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { + // Fall back to safe Unmarshal because the type in not packed. + ke.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := task.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) +} + +func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte { + buf := task.CopyScratchBuffer(ke.SizeBytes()) + ke.MarshalBytes(buf) + return buf +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := w.Write(buf) + return int64(length), err +} + +var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) + // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. @@ -374,12 +500,6 @@ type IPTReplace struct { // Entries [0]IPTEntry } -// KernelIPTReplace is identical to IPTReplace, but includes the Entries field. -type KernelIPTReplace struct { - IPTReplace - Entries [0]IPTEntry -} - // SizeOfIPTReplace is the size of an IPTReplace. const SizeOfIPTReplace = 96 @@ -392,6 +512,8 @@ func (en ExtensionName) String() string { } // TableName holds the name of a netfilter table. +// +// +marshal type TableName [XT_TABLE_MAXNAMELEN]byte // String implements fmt.Stringer. diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 95337c168..c24a8216e 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -234,6 +234,8 @@ const ( const SockAddrMax = 128 // InetAddr is struct in_addr, from uapi/linux/in.h. +// +// +marshal type InetAddr [4]byte // SockAddrInet is struct sockaddr_in, from uapi/linux/in.h. @@ -303,6 +305,8 @@ func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} // Linger is struct linger, from include/linux/socket.h. +// +// +marshal type Linger struct { OnOff int32 Linger int32 @@ -317,6 +321,8 @@ const SizeOfLinger = 8 // the end of this struct or within existing unusued space, so its size grows // over time. The current iteration is based on linux v4.17. New versions are // always backwards compatible. +// +// +marshal type TCPInfo struct { State uint8 CaState uint8 @@ -414,6 +420,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{})) // A ControlMessageCredentials is an SCM_CREDENTIALS socket control message. // // ControlMessageCredentials represents struct ucred from linux/socket.h. +// +// +marshal type ControlMessageCredentials struct { PID int32 UID uint32 diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index cd5f5049e..491495b16 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -150,11 +150,9 @@ afterSymlink: return nil, err } if d != d.parent && !d.cachedMetadataAuthoritative() { - _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask()) - if err != nil { + if err := d.parent.updateFromGetattr(ctx); err != nil { return nil, err } - d.parent.updateFromP9Attrs(attrMask, &attr) } rp.Advance() return d.parent, nil @@ -209,17 +207,28 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil // Preconditions: As for getChildLocked. !parent.isSynthetic(). func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) { + if child != nil { + // Need to lock child.metadataMu because we might be updating child + // metadata. We need to hold the lock *before* getting metadata from the + // server and release it after updating local metadata. + child.metadataMu.Lock() + } qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) if err != nil && err != syserror.ENOENT { + if child != nil { + child.metadataMu.Unlock() + } return nil, err } if child != nil { if !file.isNil() && inoFromPath(qid.Path) == child.ino { // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) - child.updateFromP9Attrs(attrMask, &attr) + child.updateFromP9AttrsLocked(attrMask, &attr) + child.metadataMu.Unlock() return child, nil } + child.metadataMu.Unlock() if file.isNil() && child.isSynthetic() { // We have a synthetic file, and no remote file has arisen to // replace it. diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index b74d489a0..ba7f650b3 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -785,8 +785,8 @@ func (d *dentry) cachedMetadataAuthoritative() bool { // updateFromP9Attrs is called to update d's metadata after an update from the // remote filesystem. -func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { - d.metadataMu.Lock() +// Precondition: d.metadataMu must be locked. +func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { if mask.Mode { if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want { d.metadataMu.Unlock() @@ -822,7 +822,6 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { if mask.Size { d.updateFileSizeLocked(attr.Size) } - d.metadataMu.Unlock() } // Preconditions: !d.isSynthetic() @@ -834,6 +833,10 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { file p9file handleMuRLocked bool ) + // d.metadataMu must be locked *before* we getAttr so that we do not end up + // updating stale attributes in d.updateFromP9AttrsLocked(). + d.metadataMu.Lock() + defer d.metadataMu.Unlock() d.handleMu.RLock() if !d.handle.file.isNil() { file = d.handle.file @@ -849,7 +852,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { if err != nil { return err } - d.updateFromP9Attrs(attrMask, &attr) + d.updateFromP9AttrsLocked(attrMask, &attr) return nil } diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index c40c6d673..c0fd3425b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,5 +20,6 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/usermem", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index ff81ea6e6..e76e498de 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -40,6 +40,8 @@ go_library( "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index a92aed2c9..ec5506efc 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -36,6 +36,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( @@ -319,7 +321,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -364,7 +366,8 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr if err != nil { return nil, syserr.FromError(err) } - return opt, nil + optP := primitive.ByteSlice(opt) + return &optP, nil } // SetSockOpt implements socket.Socket.SetSockOpt. diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 1243143ea..d9394055d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -145,7 +145,7 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin // Each rule corresponds to an entry. entry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ + Entry: linux.IPTEntry{ IP: linux.IPTIP{ Protocol: uint16(rule.Filter.Protocol), }, @@ -153,20 +153,20 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin TargetOffset: linux.SizeOfIPTEntry, }, } - copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) - copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) - copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) - copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) - copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) - copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IP.Src[:], rule.Filter.Src) + copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP } if rule.Filter.SrcInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP } if rule.Filter.OutputInterfaceInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } for _, matcher := range rule.Matchers { @@ -178,8 +178,8 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - entry.TargetOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) } // Serialize and append the target. @@ -188,11 +188,11 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) nflog("convert to binary: adding entry: %+v", entry) - entries.Size += uint32(entry.NextOffset) + entries.Size += uint32(entry.Entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) info.NumEntries++ } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index d5ca3ac56..0546801bf 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -36,6 +36,8 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 81f34c5a2..98ca7add0 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -38,6 +38,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const sizeOfInt32 int = 4 @@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - return int32(s.sendBufferSize), nil + sendBufferSizeP := primitive.Int32(s.sendBufferSize) + return &sendBufferSizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - return int32(math.MaxInt32), nil + recvBufferSizeP := primitive.Int32(math.MaxInt32) + return &recvBufferSizeP, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var passcred int32 + var passcred primitive.Int32 if s.Passcred() { passcred = 1 } - return passcred, nil + return &passcred, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index ea6ebd0e2..1fb777a6c 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -51,6 +51,8 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 964ec8414..9856ab8c5 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -62,6 +62,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { @@ -910,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -920,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -956,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -971,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -981,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -1014,7 +1016,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1025,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // Get the last error and convert it. err := ep.GetSockOpt(tcpip.ErrorOption{}) if err == nil { - return int32(0), nil + optP := primitive.Int32(0) + return &optP, nil } - return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil + + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + return &optP, nil case linux.SO_PEERCRED: if family != linux.AF_UNIX || outLen < syscall.SizeofUcred { @@ -1035,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } tcred := t.Credentials() - return syscall.Ucred{ - Pid: int32(t.ThreadGroup().ID()), - Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - }, nil + creds := linux.ControlMessageCredentials{ + PID: int32(t.ThreadGroup().ID()), + UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), + GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), + } + return &creds, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -1050,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1066,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { @@ -1082,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_REUSEADDR: if outLen < sizeOfInt32 { @@ -1093,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { @@ -1104,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1112,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } if v == 0 { - return []byte{}, nil + var b primitive.ByteSlice + return &b, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument @@ -1127,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // interface was removed. return nil, syserr.ErrUnknownDevice } - return append([]byte(nic.Name), 0), nil + + name := primitive.ByteSlice(append([]byte(nic.Name), 0)) + return &name, nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1138,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { @@ -1149,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - return linux.Linger{}, nil + + linger := linux.Linger{} + return &linger, nil case linux.SO_SNDTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1163,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.SendTimeout()), nil + sendTimeout := linux.NsecToTimeval(s.SendTimeout()) + return &sendTimeout, nil case linux.SO_RCVTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1171,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.RecvTimeout()), nil + recvTimeout := linux.NsecToTimeval(s.RecvTimeout()) + return &recvTimeout, nil case linux.SO_OOBINLINE: if outLen < sizeOfInt32 { @@ -1183,7 +1207,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.SO_NO_CHECK: if outLen < sizeOfInt32 { @@ -1194,7 +1219,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1203,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { @@ -1214,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(!v), nil + + vP := primitive.Int32(boolToInt32(!v)) + return &vP, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { @@ -1225,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { @@ -1236,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1247,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { @@ -1259,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveIdle, nil case linux.TCP_KEEPINTVL: if outLen < sizeOfInt32 { @@ -1271,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveInterval, nil case linux.TCP_KEEPCNT: if outLen < sizeOfInt32 { @@ -1283,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_USER_TIMEOUT: if outLen < sizeOfInt32 { @@ -1295,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Millisecond), nil + tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond) + return &tcpUserTimeout, nil case linux.TCP_INFO: var v tcpip.TCPInfoOption @@ -1309,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa info := linux.TCPInfo{} // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &info) - if len(ib) > outLen { - ib = ib[:outLen] + buf := t.CopyScratchBuffer(info.SizeBytes()) + info.MarshalUnsafe(buf) + if len(buf) > outLen { + buf = buf[:outLen] } - - return ib, nil + bufP := primitive.ByteSlice(buf) + return &bufP, nil case linux.TCP_CC_INFO, linux.TCP_NOTSENT_LOWAT, @@ -1344,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } b := make([]byte, toCopy) copy(b, v) - return b, nil + + bP := primitive.ByteSlice(b) + return &bP, nil case linux.TCP_LINGER2: if outLen < sizeOfInt32 { @@ -1356,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + return &lingerTimeout, nil case linux.TCP_DEFER_ACCEPT: if outLen < sizeOfInt32 { @@ -1368,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second) + return &tcpDeferAccept, nil case linux.TCP_SYNCNT: if outLen < sizeOfInt32 { @@ -1379,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_WINDOW_CLAMP: if outLen < sizeOfInt32 { @@ -1391,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil default: emitUnimplementedEventTCP(t, name) } @@ -1400,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { @@ -1411,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1419,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_TCLASS: // Length handling for parity with Linux. if outLen == 0 { - return make([]byte, 0), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - uintv := uint32(v) + uintv := primitive.Uint32(v) // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + ib := t.CopyScratchBuffer(uintv.SizeBytes()) + uintv.MarshalUnsafe(ib) // Handle cases where outLen is lesser than sizeOfInt32. if len(ib) > outLen { ib = ib[:outLen] } - return ib, nil + ibP := primitive.ByteSlice(ib) + return &ibP, nil case linux.IPV6_RECVTCLASS: if outLen < sizeOfInt32 { @@ -1444,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIPv6(t, name) @@ -1453,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1466,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } // Fill in the default value, if needed. - if v == 0 { - v = DefaultTTL + vP := primitive.Int32(v) + if vP == 0 { + vP = DefaultTTL } - return int32(v), nil + return &vP, nil case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { @@ -1482,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_MULTICAST_IF: if outLen < len(linux.InetAddr{}) { @@ -1496,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(*linux.SockAddrInet).Addr, nil + return &a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1507,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { - return []byte(nil), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { - return uint8(v), nil + vP := primitive.Uint8(v) + return &vP, nil } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_RECVTOS: if outLen < sizeOfInt32 { @@ -1532,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1543,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIP(t, name) diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index d65a89316..a9025b0ec 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -31,6 +31,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // SocketVFS2 encapsulates all the state needed to represent a network stack @@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketVFS2 rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return entries, nil + return &entries, nil } } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fcd7f9d7f..d112757fb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // ControlMessages represents the union of unix control messages and tcpip @@ -86,7 +87,7 @@ type SocketOps interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux syscall. - GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux syscall. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cca5e70f1..061a689a9 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -35,5 +35,6 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4bb2b6ff4..0482d33cf 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -40,6 +40,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index ff2149250..05c16fcfe 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketVFS2 implements socket.SocketVFS2 (and by extension, @@ -89,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 217fcfef2..4a9b04fd0 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -99,5 +99,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 0760af77b..414fce8e3 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -29,6 +29,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // LINT.IfChange @@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 0c740335b..64696b438 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -72,5 +72,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 10b668477..8096a8f9c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -30,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // minListenBacklog is the minimum reasonable backlog for listening sockets. @@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 945a364a7..63ab11f8c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -15,12 +15,15 @@ package vfs2 import ( + "io" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -110,16 +113,20 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Move data. var ( - n int64 - err error - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { // If both input and output are pipes, delegate to the pipe - // implementation. Otherwise, exactly one end is a pipe, which we - // ensure is consistently ordered after the non-pipe FD's locks by - // passing the pipe FD as usermem.IO to the non-pipe end. + // implementation. Otherwise, exactly one end is a pipe, which + // we ensure is consistently ordered after the non-pipe FD's + // locks by passing the pipe FD as usermem.IO to the non-pipe + // end. switch { case inIsPipe && outIsPipe: n, err = pipe.Splice(t, outPipeFD, inPipeFD, count) @@ -137,38 +144,15 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } else { n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } + default: + panic("not possible") } + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { break } - - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the splice operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. - } - if err = t.Block(inCh); err != nil { - break - } - } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. - } - if err = t.Block(outCh); err != nil { - break - } + if err = dw.waitForBoth(t); err != nil { + break } } @@ -247,45 +231,256 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo // Copy data. var ( - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { - n, err := pipe.Tee(t, outPipeFD, inPipeFD, count) - if n != 0 { - return uintptr(n), nil, nil + n, err = pipe.Tee(t, outPipeFD, inPipeFD, count) + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + break + } + if err = dw.waitForBoth(t); err != nil { + break + } + } + if n == 0 { + return 0, nil, err + } + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// Sendfile implements linux system call sendfile(2). +func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + outFD := args[0].Int() + inFD := args[1].Int() + offsetAddr := args[2].Pointer() + count := int64(args[3].SizeT()) + + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + if !inFile.IsReadable() { + return 0, nil, syserror.EBADF + } + + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + if !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // Verify that the outFile Append flag is not set. + if outFile.StatusFlags()&linux.O_APPEND != 0 { + return 0, nil, syserror.EINVAL + } + + // Verify that inFile is a regular file or block device. This is a + // requirement; the same check appears in Linux + // (fs/splice.c:splice_direct_to_actor). + if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil { + return 0, nil, err + } else if stat.Mask&linux.STATX_TYPE == 0 || + (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) { + return 0, nil, syserror.EINVAL + } + + // Copy offset if it exists. + offset := int64(-1) + if offsetAddr != 0 { + if inFile.Options().DenyPRead { + return 0, nil, syserror.ESPIPE } - if err != syserror.ErrWouldBlock || nonBlock { + if _, err := t.CopyIn(offsetAddr, &offset); err != nil { return 0, nil, err } + if offset < 0 { + return 0, nil, syserror.EINVAL + } + if offset+count < 0 { + return 0, nil, syserror.EINVAL + } + } + + // Validate count. This must come after offset checks. + if count < 0 { + return 0, nil, syserror.EINVAL + } + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the tee operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. + // Copy data. + var ( + n int64 + err error + ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + // Reading from input file should never block, since it is regular or + // block device. We only need to check if writing to the output file + // can block. + nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0 + if outIsPipe { + for n < count { + var spliceN int64 + if offset != -1 { + spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{}) + offset += spliceN + } else { + spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } - if err := t.Block(inCh); err != nil { - return 0, nil, err + n += spliceN + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) + } + if err != nil { + break } } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. + } else { + // Read inFile to buffer, then write the contents to outFile. + buf := make([]byte, count) + for n < count { + var readN int64 + if offset != -1 { + readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{}) + offset += readN + } else { + readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + } + if readN == 0 && err == io.EOF { + // We reached the end of the file. Eat the + // error and exit the loop. + err = nil + break } - if err := t.Block(outCh); err != nil { - return 0, nil, err + n += readN + if err != nil { + break + } + + // Write all of the bytes that we read. This may need + // multiple write calls to complete. + wbuf := buf[:n] + for len(wbuf) > 0 { + var writeN int64 + writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{}) + wbuf = wbuf[writeN:] + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForOut(t) + } + if err != nil { + // We didn't complete the write. Only + // report the bytes that were actually + // written, and rewind the offset. + notWritten := int64(len(wbuf)) + n -= notWritten + if offset != -1 { + offset -= notWritten + } + break + } + } + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) } + if err != nil { + break + } + } + } + + if offsetAddr != 0 { + // Copy out the new offset. + if _, err := t.CopyOut(offsetAddr, offset); err != nil { + return 0, nil, err + } + } + + if n == 0 { + return 0, nil, err + } + + inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not +// thread-safe, and does not take a reference on the vfs.FileDescriptions. +// +// Users must call destroy() when finished. +type dualWaiter struct { + inFile *vfs.FileDescription + outFile *vfs.FileDescription + + inW waiter.Entry + inCh chan struct{} + outW waiter.Entry + outCh chan struct{} +} + +// waitForBoth waits for both dw.inFile and dw.outFile to be ready. +func (dw *dualWaiter) waitForBoth(t *kernel.Task) error { + if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if dw.inCh == nil { + dw.inW, dw.inCh = waiter.NewChannelEntry(nil) + dw.inFile.EventRegister(&dw.inW, eventMaskRead) + // We might be ready now. Try again before blocking. + return nil + } + if err := t.Block(dw.inCh); err != nil { + return err + } + } + return dw.waitForOut(t) +} + +// waitForOut waits for dw.outfile to be read. +func (dw *dualWaiter) waitForOut(t *kernel.Task) error { + if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if dw.outCh == nil { + dw.outW, dw.outCh = waiter.NewChannelEntry(nil) + dw.outFile.EventRegister(&dw.outW, eventMaskWrite) + // We might be ready now. Try again before blocking. + return nil } + if err := t.Block(dw.outCh); err != nil { + return err + } + } + return nil +} + +// destroy cleans up resources help by dw. No more calls to wait* can occur +// after destroy is called. +func (dw *dualWaiter) destroy() { + if dw.inCh != nil { + dw.inFile.EventUnregister(&dw.inW) + dw.inCh = nil + } + if dw.outCh != nil { + dw.outFile.EventUnregister(&dw.outW) + dw.outCh = nil } + dw.inFile = nil + dw.outFile = nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index 8f497ecc7..1b2cfad7d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -44,7 +44,7 @@ func Override() { s.Table[23] = syscalls.Supported("select", Select) s.Table[32] = syscalls.Supported("dup", Dup) s.Table[33] = syscalls.Supported("dup2", Dup2) - delete(s.Table, 40) // sendfile + s.Table[40] = syscalls.Supported("sendfile", Sendfile) s.Table[41] = syscalls.Supported("socket", Socket) s.Table[42] = syscalls.Supported("connect", Connect) s.Table[43] = syscalls.Supported("accept", Accept) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index d39baf620..559a1c4dd 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -49,7 +49,8 @@ const ( type manipType int const ( - manipDstPrerouting manipType = iota + manipNone manipType = iota + manipDstPrerouting manipDstOutput ) @@ -113,13 +114,11 @@ type conn struct { // update the state of tcb. It is immutable. tcbHook Hook - // mu protects tcb. + // mu protects all mutable state. mu sync.Mutex `state:"nosave"` - // tcb is TCB control block. It is used to keep track of states // of tcp connection and is protected by mu. tcb tcpconntrack.TCB - // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. It is protected by mu. lastUsed time.Time `state:".(unixTime)"` @@ -141,8 +140,26 @@ func (cn *conn) timedOut(now time.Time) bool { return now.Sub(cn.lastUsed) > defaultTimeout } +// update the connection tracking state. +// +// Precondition: ct.mu must be held. +func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { + // Update the state of tcb. tcb assumes it's always initialized on the + // client. However, we only need to know whether the connection is + // established or not, so the client/server distinction isn't important. + // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle + // other tcp states. + if ct.tcb.IsEmpty() { + ct.tcb.Init(tcpHeader) + } else if hook == ct.tcbHook { + ct.tcb.UpdateStateOutbound(tcpHeader) + } else { + ct.tcb.UpdateStateInbound(tcpHeader) + } +} + // ConnTrack tracks all connections created for NAT rules. Most users are -// expected to only call handlePacket and createConnFor. +// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. // // ConnTrack keeps all connections in a slice of buckets, each of which holds a // linked list of tuples. This gives us some desirable properties: @@ -248,8 +265,7 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { return nil, dirOriginal } -// createConnFor creates a new conn for pkt. -func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -272,10 +288,15 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg manip = manipDstOutput } conn := newConn(tid, replyTID, manip, hook) + ct.insertConn(conn) + return conn +} +// insertConn inserts conn into the appropriate table bucket. +func (ct *ConnTrack) insertConn(conn *conn) { // Lock the buckets in the correct order. - tupleBucket := ct.bucket(tid) - replyBucket := ct.bucket(replyTID) + tupleBucket := ct.bucket(conn.original.tupleID) + replyBucket := ct.bucket(conn.reply.tupleID) ct.mu.RLock() defer ct.mu.RUnlock() if tupleBucket < replyBucket { @@ -289,22 +310,37 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg ct.buckets[tupleBucket].mu.Lock() } - // Add the tuple to the map. - ct.buckets[tupleBucket].tuples.PushFront(&conn.original) - ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + // Now that we hold the locks, ensure the tuple hasn't been inserted by + // another thread. + alreadyInserted := false + for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { + if other.tupleID == conn.original.tupleID { + alreadyInserted = true + break + } + } + + if !alreadyInserted { + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + } // Unlocking can happen in any order. ct.buckets[tupleBucket].mu.Unlock() if tupleBucket != replyBucket { ct.buckets[replyBucket].mu.Unlock() } - - return conn } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. // TODO(gvisor.dev/issue/170): Change address for Prerouting hook. func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -322,12 +358,22 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { netHeader.SetSourceAddress(conn.original.dstAddr) } + // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated + // on inbound packets, so we don't recalculate them. However, we should + // support cases when they are validated, e.g. when we can't offload + // receive checksumming. + netHeader.SetChecksum(0) netHeader.SetChecksum(^netHeader.CalculateChecksum()) } // handlePacketOutput manipulates ports for packets in Output hook. func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -362,20 +408,31 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d } // handlePacket will manipulate the port and address of the packet if the -// connection exists. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +// connection exists. Returns whether, after the packet traverses the tables, +// it should create a new entry in the table. +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { if pkt.NatDone { - return + return false } if hook != Prerouting && hook != Output { - return + return false + } + + // TODO(gvisor.dev/issue/170): Support other transport protocols. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return false } conn, dir := ct.connFor(pkt) + // Connection or Rule not found for the packet. if conn == nil { - // Connection not found for the packet or the packet is invalid. - return + return true + } + + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return false } switch hook { @@ -395,14 +452,39 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - if tcpHeader := header.TCP(pkt.TransportHeader); conn.tcb.IsEmpty() { - conn.tcb.Init(tcpHeader) - conn.tcbHook = hook - } else if hook == conn.tcbHook { - conn.tcb.UpdateStateOutbound(tcpHeader) - } else { - conn.tcb.UpdateStateInbound(tcpHeader) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + + return false +} + +// maybeInsertNoop tries to insert a no-op connection entry to keep connections +// from getting clobbered when replies arrive. It only inserts if there isn't +// already a connection for pkt. +// +// This should be called after traversing iptables rules only, to ensure that +// pkt.NatDone is set correctly. +func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { + // If there were a rule applying to this packet, it would be marked + // with NatDone. + if pkt.NatDone { + return + } + + // We only track TCP connections. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return + } + + // This is the first packet we're seeing for the TCP connection. Insert + // the noop entry (an identity mapping) so that the response doesn't + // get NATed, breaking the connection. + tid, err := packetToTupleID(pkt) + if err != nil { + return } + conn := newConn(tid, tid.reply(), manipNone, hook) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + ct.insertConn(conn) } // bucket gets the conntrack bucket for a tupleID. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 5f647c5fe..ca1dda695 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -245,13 +245,18 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.handlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. it.mu.RLock() defer it.mu.RUnlock() priorities := it.priorities[hook] for _, tableID := range priorities { + // If handlePacket already NATed the packet, we don't need to + // check the NAT table. + if tableID == natID && pkt.NatDone { + continue + } table := it.tables[tableID] ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { @@ -281,6 +286,20 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr } } + // If this connection should be tracked, try to add an entry for it. If + // traversing the nat table didn't end in adding an entry, + // maybeInsertNoop will add a no-op entry for the connection. This is + // needeed when establishing connections so that the SYN/ACK reply to an + // outgoing SYN is delivered to the correct endpoint rather than being + // redirected by a prerouting rule. + // + // From the iptables documentation: "If there is no rule, a `null' + // binding is created: this usually does not map the packet, but exists + // to ensure we don't map another stream over an existing one." + if shouldTrack { + it.connections.maybeInsertNoop(pkt, hook) + } + // Every table returned Accept. return true } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index d43f60c67..dc88033c7 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -153,7 +153,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. - if conn := ct.createConnFor(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index c2e9fd29f..aefe0e2b2 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -456,7 +456,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { defer e.mu.Unlock() // If a local address was specified, verify that it's valid. - if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } |