diff options
Diffstat (limited to 'pkg/sentry/socket')
28 files changed, 930 insertions, 718 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index c40c6d673..c0fd3425b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,5 +20,6 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/usermem", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 8b439a078..70ccf77a7 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -68,7 +68,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { for _, fd := range fds { file := t.GetFile(fd) if file == nil { - files.Release() + files.Release(t) return nil, syserror.EBADF } files = append(files, file) @@ -100,9 +100,9 @@ func (fs *RightsFiles) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (fs *RightsFiles) Release() { +func (fs *RightsFiles) Release(ctx context.Context) { for _, f := range *fs { - f.DecRef() + f.DecRef(ctx) } *fs = nil } @@ -115,7 +115,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32 fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{ CloseOnExec: cloexec, }) - files[0].DecRef() + files[0].DecRef(t) files = files[1:] if err != nil { t.Warningf("Error inserting FD: %v", err) diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index fd08179be..d9621968c 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -46,7 +46,7 @@ func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) { for _, fd := range fds { file := t.GetFileVFS2(fd) if file == nil { - files.Release() + files.Release(t) return nil, syserror.EBADF } files = append(files, file) @@ -78,9 +78,9 @@ func (fs *RightsFilesVFS2) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (fs *RightsFilesVFS2) Release() { +func (fs *RightsFilesVFS2) Release(ctx context.Context) { for _, f := range *fs { - f.DecRef() + f.DecRef(ctx) } *fs = nil } @@ -93,7 +93,7 @@ func rightsFDsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, max int) fd, err := t.NewFDFromVFS2(0, files[0], kernel.FDFlags{ CloseOnExec: cloexec, }) - files[0].DecRef() + files[0].DecRef(t) files = files[1:] if err != nil { t.Warningf("Error inserting FD: %v", err) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index e82d6cd1e..e76e498de 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -26,6 +26,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/hostfd", "//pkg/sentry/inet", @@ -39,6 +40,8 @@ go_library( "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index c11e82c10..242e6bf76 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -36,6 +36,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( @@ -98,12 +100,12 @@ func newSocketFile(ctx context.Context, family int, stype linux.SockType, protoc return nil, syserr.FromError(err) } dirent := socket.NewDirent(ctx, socketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { +func (s *socketOpsCommon) Release(context.Context) { fdnotifier.RemoveFD(int32(s.fd)) syscall.Close(s.fd) } @@ -267,7 +269,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, syscall.Close(fd) return 0, nil, 0, err } - defer f.DecRef() + defer f.DecRef(t) kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, @@ -279,7 +281,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, syscall.Close(fd) return 0, nil, 0, err } - defer f.DecRef() + defer f.DecRef(t) kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, @@ -319,12 +321,12 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } - // Whitelist options and constrain option length. + // Only allow known and safe options. optlen := getSockOptLen(t, level, name) switch level { case linux.SOL_IP: @@ -364,12 +366,13 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr if err != nil { return nil, syserr.FromError(err) } - return opt, nil + optP := primitive.ByteSlice(opt) + return &optP, nil } // SetSockOpt implements socket.Socket.SetSockOpt. func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { - // Whitelist options and constrain option length. + // Only allow known and safe options. optlen := setSockOptLen(t, level, name) switch level { case linux.SOL_IP: @@ -415,7 +418,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] // RecvMsg implements socket.Socket.RecvMsg. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { - // Whitelist flags. + // Only allow known and safe flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the @@ -537,7 +540,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // SendMsg implements socket.Socket.SendMsg. func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { - // Whitelist flags. + // Only allow known and safe flags. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 { return 0, syserr.ErrInvalidArgument } @@ -708,6 +711,6 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int func init() { for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} { socket.RegisterProvider(family, &socketProvider{family}) - socket.RegisterProviderVFS2(family, &socketProviderVFS2{}) + socket.RegisterProviderVFS2(family, &socketProviderVFS2{family}) } } diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 677743113..8a1d52ebf 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -35,6 +36,7 @@ import ( type socketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl + vfs.LockFD // We store metadata for hostinet sockets internally. Technically, we should // access metadata (e.g. through stat, chmod) on the host for correctness, @@ -59,6 +61,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in fd: fd, }, } + s.LockFD.Init(&vfs.FileLocks{}) if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } @@ -68,6 +71,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in DenyPWrite: true, UseDentryMetadata: true, }); err != nil { + fdnotifier.RemoveFD(int32(s.fd)) return nil, syserr.FromError(err) } return vfsfd, nil @@ -93,7 +97,12 @@ func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal return ioctl(ctx, s.fd, uio, args) } -// PRead implements vfs.FileDescriptionImpl. +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.ENODEV +} + +// PRead implements vfs.FileDescriptionImpl.PRead. func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { return 0, syserror.ESPIPE } @@ -131,6 +140,16 @@ func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return int64(n), err } +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} + type socketProviderVFS2 struct { family int } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index a48082631..fda3dcb35 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -53,6 +53,7 @@ type Stack struct { interfaceAddrs map[int32][]inet.InterfaceAddr routes []inet.Route supportsIPv6 bool + tcpRecovery inet.TCPLossRecovery tcpRecvBufSize inet.TCPBufferSize tcpSendBufSize inet.TCPBufferSize tcpSACKEnabled bool @@ -350,6 +351,16 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + return s.tcpRecovery, nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + return syserror.EACCES +} + // getLine reads one line from proc file, with specified prefix. // The last argument, withHeader, specifies if it contains line header. func getLine(f *os.File, prefix string, withHeader bool) string { diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 789bb94c8..e91b0624c 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -32,28 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// errorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const errorTargetName = "ERROR" - -// redirectTargetName is used to mark targets as redirect targets. Redirect -// targets should be reached for only NAT and Mangle tables. These targets will -// change the destination port/destination IP for packets. -const redirectTargetName = "REDIRECT" - -// Metadata is used to verify that we are correctly serializing and -// deserializing iptables into structs consumable by the iptables tool. We save -// a metadata struct when the tables are written, and when they are read out we -// verify that certain fields are the same. -// -// metadata is used by this serialization/deserializing code, not netstack. -type metadata struct { - HookEntry [linux.NF_INET_NUMHOOKS]uint32 - Underflow [linux.NF_INET_NUMHOOKS]uint32 - NumEntries uint32 - Size uint32 -} - // enableLogging controls whether to log the (de)serialization of netfilter // structs between userspace and netstack. These logs are useful when // developing iptables, but can pollute sentry logs otherwise. @@ -64,6 +42,8 @@ const enableLogging = false var emptyFilter = stack.IPHeaderFilter{ Dst: "\x00\x00\x00\x00", DstMask: "\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00", } // nflog logs messages related to the writing and reading of iptables. @@ -77,33 +57,17 @@ func nflog(format string, args ...interface{}) { func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. var info linux.IPTGetinfo - if _, err := t.CopyIn(outPtr, &info); err != nil { + if _, err := info.CopyIn(t, outPtr); err != nil { return linux.IPTGetinfo{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(stack, info.Name) + _, info, err := convertNetstackToBinary(stack, info.Name) if err != nil { - nflog("%v", err) + nflog("couldn't convert iptables: %v", err) return linux.IPTGetinfo{}, syserr.ErrInvalidArgument } - // Get the hooks that apply to this table. - info.ValidHooks = table.ValidHooks() - - // Grab the metadata struct, which is used to store information (e.g. - // the number of entries) that applies to the user's encoding of - // iptables, but not netstack's. - metadata := table.Metadata().(metadata) - - // Set values from metadata. - info.HookEntry = metadata.HookEntry - info.Underflow = metadata.Underflow - info.NumEntries = metadata.NumEntries - info.Size = metadata.Size - nflog("returning info: %+v", info) - return info, nil } @@ -111,28 +75,18 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { // Read in the struct and table name. var userEntries linux.IPTGetEntries - if _, err := t.CopyIn(outPtr, &userEntries); err != nil { + if _, err := userEntries.CopyIn(t, outPtr); err != nil { nflog("couldn't copy in entries %q", userEntries.Name) return linux.KernelIPTGetEntries{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(stack, userEntries.Name) - if err != nil { - nflog("%v", err) - return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument - } - // Convert netstack's iptables rules to something that the iptables // tool can understand. - entries, meta, err := convertNetstackToBinary(userEntries.Name.String(), table) + entries, _, err := convertNetstackToBinary(stack, userEntries.Name) if err != nil { nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } - if meta != table.Metadata().(metadata) { - panic(fmt.Sprintf("Table %q metadata changed between writing and reading. Was saved as %+v, but is now %+v", userEntries.Name.String(), table.Metadata().(metadata), meta)) - } if binary.Size(entries) > uintptr(outLen) { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument @@ -141,48 +95,26 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return entries, nil } -func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - ipt := stk.IPTables() - table, ok := ipt.Tables[tablename.String()] - if !ok { - return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) - } - return table, nil -} - -// FillDefaultIPTables sets stack's IPTables to the default tables and -// populates them with metadata. -func FillDefaultIPTables(stk *stack.Stack) { - ipt := stack.DefaultTables() - - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range ipt.Tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) - } - table.SetMetadata(metadata) - ipt.Tables[name] = table - } - - stk.SetIPTables(ipt) -} - // convertNetstackToBinary converts the iptables as stored in netstack to the // format expected by the iptables tool. Linux stores each table as a binary // blob that can only be traversed by parsing a bit, reading some offsets, // jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) { - // Return values. +func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { + table, ok := stack.IPTables().GetTable(tablename.String()) + if !ok { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + var entries linux.KernelIPTGetEntries - var meta metadata + var info linux.IPTGetinfo + info.ValidHooks = table.ValidHooks() // The table name has to fit in the struct. if linux.XT_TABLE_MAXNAMELEN < len(tablename) { - return linux.KernelIPTGetEntries{}, metadata{}, fmt.Errorf("table name %q too long.", tablename) + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - copy(entries.Name[:], tablename) + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], tablename[:]) for ruleIdx, rule := range table.Rules { nflog("convert to binary: current offset: %d", entries.Size) @@ -191,20 +123,20 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI for hook, hookRuleIdx := range table.BuiltinChains { if hookRuleIdx == ruleIdx { nflog("convert to binary: found hook %d at offset %d", hook, entries.Size) - meta.HookEntry[hook] = entries.Size + info.HookEntry[hook] = entries.Size } } // Is this a chain underflow point? for underflow, underflowRuleIdx := range table.Underflows { if underflowRuleIdx == ruleIdx { nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size) - meta.Underflow[underflow] = entries.Size + info.Underflow[underflow] = entries.Size } } // Each rule corresponds to an entry. entry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ + Entry: linux.IPTEntry{ IP: linux.IPTIP{ Protocol: uint16(rule.Filter.Protocol), }, @@ -212,15 +144,20 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI TargetOffset: linux.SizeOfIPTEntry, }, } - copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) - copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) - copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) - copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IP.Src[:], rule.Filter.Src) + copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP + } + if rule.Filter.SrcInvert { + entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP } if rule.Filter.OutputInterfaceInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } for _, matcher := range rule.Matchers { @@ -232,8 +169,8 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - entry.TargetOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) } // Serialize and append the target. @@ -242,142 +179,18 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) nflog("convert to binary: adding entry: %+v", entry) - entries.Size += uint32(entry.NextOffset) + entries.Size += uint32(entry.Entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) - meta.NumEntries++ - } - - nflog("convert to binary: finished with an marshalled size of %d", meta.Size) - meta.Size = entries.Size - return entries, meta, nil -} - -func marshalTarget(target stack.Target) []byte { - switch tg := target.(type) { - case stack.AcceptTarget: - return marshalStandardTarget(stack.RuleAccept) - case stack.DropTarget: - return marshalStandardTarget(stack.RuleDrop) - case stack.ErrorTarget: - return marshalErrorTarget(errorTargetName) - case stack.UserChainTarget: - return marshalErrorTarget(tg.Name) - case stack.ReturnTarget: - return marshalStandardTarget(stack.RuleReturn) - case stack.RedirectTarget: - return marshalRedirectTarget(tg) - case JumpTarget: - return marshalJumpTarget(tg) - default: - panic(fmt.Errorf("unknown target of type %T", target)) - } -} - -func marshalStandardTarget(verdict stack.RuleVerdict) []byte { - nflog("convert to binary: marshalling standard target") - - // The target's name will be the empty string. - target := linux.XTStandardTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTStandardTarget, - }, - Verdict: translateFromStandardVerdict(verdict), + info.NumEntries++ } - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -func marshalErrorTarget(errorName string) []byte { - // This is an error target named error - target := linux.XTErrorTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTErrorTarget, - }, - } - copy(target.Name[:], errorName) - copy(target.Target.Name[:], errorTargetName) - - ret := make([]byte, 0, linux.SizeOfXTErrorTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -func marshalRedirectTarget(rt stack.RedirectTarget) []byte { - // This is a redirect target named redirect - target := linux.XTRedirectTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTRedirectTarget, - }, - } - copy(target.Target.Name[:], redirectTargetName) - - ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) - target.NfRange.RangeSize = 1 - if rt.RangeProtoSpecified { - target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED - } - // Convert port from little endian to big endian. - port := make([]byte, 2) - binary.LittleEndian.PutUint16(port, rt.MinPort) - target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port) - binary.LittleEndian.PutUint16(port, rt.MaxPort) - target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -func marshalJumpTarget(jt JumpTarget) []byte { - nflog("convert to binary: marshalling jump target") - - // The target's name will be the empty string. - target := linux.XTStandardTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTStandardTarget, - }, - // Verdict is overloaded by the ABI. When positive, it holds - // the jump offset from the start of the table. - Verdict: int32(jt.Offset), - } - - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -// translateFromStandardVerdict translates verdicts the same way as the iptables -// tool. -func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 { - switch verdict { - case stack.RuleAccept: - return -linux.NF_ACCEPT - 1 - case stack.RuleDrop: - return -linux.NF_DROP - 1 - case stack.RuleReturn: - return linux.NF_RETURN - default: - // TODO(gvisor.dev/issue/170): Support Jump. - panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) - } -} - -// translateToStandardTarget translates from the value in a -// linux.XTStandardTarget to an stack.Verdict. -func translateToStandardTarget(val int32) (stack.Target, error) { - // TODO(gvisor.dev/issue/170): Support other verdicts. - switch val { - case -linux.NF_ACCEPT - 1: - return stack.AcceptTarget{}, nil - case -linux.NF_DROP - 1: - return stack.DropTarget{}, nil - case -linux.NF_QUEUE - 1: - return nil, errors.New("unsupported iptables verdict QUEUE") - case linux.NF_RETURN: - return stack.ReturnTarget{}, nil - default: - return nil, fmt.Errorf("unknown iptables verdict %d", val) - } + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + info.Size = entries.Size + return entries, info, nil } // SetEntries sets iptables rules for a single table. See @@ -396,10 +209,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { - case stack.TablenameFilter: + case stack.FilterTable: table = stack.EmptyFilterTable() - case stack.TablenameNat: - table = stack.EmptyNatTable() + case stack.NATTable: + table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument @@ -485,6 +298,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { for hook, _ := range replace.HookEntry { if table.ValidHooks()&(1<<hook) != 0 { hk := hookFromLinux(hook) + table.BuiltinChains[hk] = stack.HookUnset + table.Underflows[hk] = stack.HookUnset for offset, ruleIdx := range offsets { if offset == replace.HookEntry[hook] { table.BuiltinChains[hk] = ruleIdx @@ -510,8 +325,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Add the user chains. for ruleIdx, rule := range table.Rules { - target, ok := rule.Target.(stack.UserChainTarget) - if !ok { + if _, ok := rule.Target.(stack.UserChainTarget); !ok { continue } @@ -527,7 +341,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { nflog("user chain's first node must have no matchers") return syserr.ErrInvalidArgument } - table.UserChains[target.Name] = ruleIdx + 1 } // Set each jump to point to the appropriate rule. Right now they hold byte @@ -553,7 +366,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, // make sure all other chains point to ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook == stack.Forward || hook == stack.Postrouting { + if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting { + if ruleIdx == stack.HookUnset { + continue + } if !isUnconditionalAccept(table.Rules[ruleIdx]) { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument @@ -566,17 +382,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - ipt := stk.IPTables() - table.SetMetadata(metadata{ - HookEntry: replace.HookEntry, - Underflow: replace.Underflow, - NumEntries: replace.NumEntries, - Size: replace.Size, - }) - ipt.Tables[replace.Name.String()] = table - stk.SetIPTables(ipt) - - return nil + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table)) } // parseMatchers parses 0 or more matchers from optVal. optVal should contain @@ -623,113 +429,6 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, return matchers, nil } -// parseTarget parses a target from optVal. optVal should contain only the -// target. -func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) { - nflog("set entries: parsing target of size %d", len(optVal)) - if len(optVal) < linux.SizeOfXTEntryTarget { - return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) - } - var target linux.XTEntryTarget - buf := optVal[:linux.SizeOfXTEntryTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &target) - switch target.Name.String() { - case "": - // Standard target. - if len(optVal) != linux.SizeOfXTStandardTarget { - return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal)) - } - var standardTarget linux.XTStandardTarget - buf = optVal[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) - - if standardTarget.Verdict < 0 { - // A Verdict < 0 indicates a non-jump verdict. - return translateToStandardTarget(standardTarget.Verdict) - } - // A verdict >= 0 indicates a jump. - return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil - - case errorTargetName: - // Error target. - if len(optVal) != linux.SizeOfXTErrorTarget { - return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal)) - } - var errorTarget linux.XTErrorTarget - buf = optVal[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) - - // Error targets are used in 2 cases: - // * An actual error case. These rules have an error - // named errorTargetName. The last entry of the table - // is usually an error case to catch any packets that - // somehow fall through every rule. - // * To mark the start of a user defined chain. These - // rules have an error with the name of the chain. - switch name := errorTarget.Name.String(); name { - case errorTargetName: - nflog("set entries: error target") - return stack.ErrorTarget{}, nil - default: - // User defined chain. - nflog("set entries: user-defined target %q", name) - return stack.UserChainTarget{Name: name}, nil - } - - case redirectTargetName: - // Redirect target. - if len(optVal) < linux.SizeOfXTRedirectTarget { - return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) - } - - if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - - var redirectTarget linux.XTRedirectTarget - buf = optVal[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) - - // Copy linux.XTRedirectTarget to stack.RedirectTarget. - var target stack.RedirectTarget - nfRange := redirectTarget.NfRange - - // RangeSize should be 1. - if nfRange.RangeSize != 1 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - - // TODO(gvisor.dev/issue/170): Check if the flags are valid. - // Also check if we need to map ports or IP. - // For now, redirect target only supports destination port change. - // Port range and IP range are not supported yet. - if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - target.RangeProtoSpecified = true - - target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) - target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:]) - - // TODO(gvisor.dev/issue/170): Port range is not supported yet. - if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - - // Convert port from big endian to little endian. - port := make([]byte, 2) - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort) - target.MinPort = binary.LittleEndian.Uint16(port) - - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort) - target.MaxPort = binary.LittleEndian.Uint16(port) - return target, nil - } - - // Unknown target. - return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String()) -} - func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if containsUnsupportedFields(iptip) { return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) @@ -737,6 +436,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) } + if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) if n == -1 { @@ -755,6 +457,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { Dst: tcpip.Address(iptip.Dst[:]), DstMask: tcpip.Address(iptip.DstMask[:]), DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, OutputInterface: ifname, OutputInterfaceMask: ifnameMask, OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, @@ -765,15 +470,13 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool { // The following features are supported: // - Protocol // - Dst and DstMask + // - Src and SrcMask // - The inverse destination IP check flag // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInetAddr = linux.InetAddr{} var emptyInterface = [linux.IFNAMSIZ]byte{} // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.Src != emptyInetAddr || - iptip.SrcMask != emptyInetAddr || - iptip.InputInterface != emptyInterface || + inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || iptip.InputInterfaceMask != emptyInterface || iptip.Flags != 0 || iptip.InverseFlags&^inverseMask != 0 diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 3863293c7..1b4e0ad79 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -111,7 +111,7 @@ func (*OwnerMatcher) Name() string { } // Match implements Matcher.Match. -func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { // Support only for OUTPUT chain. // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. if hook != stack.Output { diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 84abe8d29..8ebdaff18 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,10 +15,257 @@ package netfilter import ( + "errors" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" ) +// errorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const errorTargetName = "ERROR" + +// redirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port/destination IP for packets. +const redirectTargetName = "REDIRECT" + +func marshalTarget(target stack.Target) []byte { + switch tg := target.(type) { + case stack.AcceptTarget: + return marshalStandardTarget(stack.RuleAccept) + case stack.DropTarget: + return marshalStandardTarget(stack.RuleDrop) + case stack.ErrorTarget: + return marshalErrorTarget(errorTargetName) + case stack.UserChainTarget: + return marshalErrorTarget(tg.Name) + case stack.ReturnTarget: + return marshalStandardTarget(stack.RuleReturn) + case stack.RedirectTarget: + return marshalRedirectTarget(tg) + case JumpTarget: + return marshalJumpTarget(tg) + default: + panic(fmt.Errorf("unknown target of type %T", target)) + } +} + +func marshalStandardTarget(verdict stack.RuleVerdict) []byte { + nflog("convert to binary: marshalling standard target") + + // The target's name will be the empty string. + target := linux.XTStandardTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTStandardTarget, + }, + Verdict: translateFromStandardVerdict(verdict), + } + + ret := make([]byte, 0, linux.SizeOfXTStandardTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +func marshalErrorTarget(errorName string) []byte { + // This is an error target named error + target := linux.XTErrorTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTErrorTarget, + }, + } + copy(target.Name[:], errorName) + copy(target.Target.Name[:], errorTargetName) + + ret := make([]byte, 0, linux.SizeOfXTErrorTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +func marshalRedirectTarget(rt stack.RedirectTarget) []byte { + // This is a redirect target named redirect + target := linux.XTRedirectTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTRedirectTarget, + }, + } + copy(target.Target.Name[:], redirectTargetName) + + ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) + target.NfRange.RangeSize = 1 + if rt.RangeProtoSpecified { + target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + } + // Convert port from little endian to big endian. + port := make([]byte, 2) + binary.LittleEndian.PutUint16(port, rt.MinPort) + target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port) + binary.LittleEndian.PutUint16(port, rt.MaxPort) + target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +func marshalJumpTarget(jt JumpTarget) []byte { + nflog("convert to binary: marshalling jump target") + + // The target's name will be the empty string. + target := linux.XTStandardTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTStandardTarget, + }, + // Verdict is overloaded by the ABI. When positive, it holds + // the jump offset from the start of the table. + Verdict: int32(jt.Offset), + } + + ret := make([]byte, 0, linux.SizeOfXTStandardTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +// translateFromStandardVerdict translates verdicts the same way as the iptables +// tool. +func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 { + switch verdict { + case stack.RuleAccept: + return -linux.NF_ACCEPT - 1 + case stack.RuleDrop: + return -linux.NF_DROP - 1 + case stack.RuleReturn: + return linux.NF_RETURN + default: + // TODO(gvisor.dev/issue/170): Support Jump. + panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) + } +} + +// translateToStandardTarget translates from the value in a +// linux.XTStandardTarget to an stack.Verdict. +func translateToStandardTarget(val int32) (stack.Target, error) { + // TODO(gvisor.dev/issue/170): Support other verdicts. + switch val { + case -linux.NF_ACCEPT - 1: + return stack.AcceptTarget{}, nil + case -linux.NF_DROP - 1: + return stack.DropTarget{}, nil + case -linux.NF_QUEUE - 1: + return nil, errors.New("unsupported iptables verdict QUEUE") + case linux.NF_RETURN: + return stack.ReturnTarget{}, nil + default: + return nil, fmt.Errorf("unknown iptables verdict %d", val) + } +} + +// parseTarget parses a target from optVal. optVal should contain only the +// target. +func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) { + nflog("set entries: parsing target of size %d", len(optVal)) + if len(optVal) < linux.SizeOfXTEntryTarget { + return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) + } + var target linux.XTEntryTarget + buf := optVal[:linux.SizeOfXTEntryTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &target) + switch target.Name.String() { + case "": + // Standard target. + if len(optVal) != linux.SizeOfXTStandardTarget { + return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal)) + } + var standardTarget linux.XTStandardTarget + buf = optVal[:linux.SizeOfXTStandardTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) + + if standardTarget.Verdict < 0 { + // A Verdict < 0 indicates a non-jump verdict. + return translateToStandardTarget(standardTarget.Verdict) + } + // A verdict >= 0 indicates a jump. + return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil + + case errorTargetName: + // Error target. + if len(optVal) != linux.SizeOfXTErrorTarget { + return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal)) + } + var errorTarget linux.XTErrorTarget + buf = optVal[:linux.SizeOfXTErrorTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) + + // Error targets are used in 2 cases: + // * An actual error case. These rules have an error + // named errorTargetName. The last entry of the table + // is usually an error case to catch any packets that + // somehow fall through every rule. + // * To mark the start of a user defined chain. These + // rules have an error with the name of the chain. + switch name := errorTarget.Name.String(); name { + case errorTargetName: + nflog("set entries: error target") + return stack.ErrorTarget{}, nil + default: + // User defined chain. + nflog("set entries: user-defined target %q", name) + return stack.UserChainTarget{Name: name}, nil + } + + case redirectTargetName: + // Redirect target. + if len(optVal) < linux.SizeOfXTRedirectTarget { + return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) + } + + if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + var redirectTarget linux.XTRedirectTarget + buf = optVal[:linux.SizeOfXTRedirectTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + + // Copy linux.XTRedirectTarget to stack.RedirectTarget. + var target stack.RedirectTarget + nfRange := redirectTarget.NfRange + + // RangeSize should be 1. + if nfRange.RangeSize != 1 { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + // TODO(gvisor.dev/issue/170): Check if the flags are valid. + // Also check if we need to map ports or IP. + // For now, redirect target only supports destination port change. + // Port range and IP range are not supported yet. + if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + target.RangeProtoSpecified = true + + target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:]) + + // TODO(gvisor.dev/issue/170): Port range is not supported yet. + if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + // Convert port from big endian to little endian. + port := make([]byte, 2) + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort) + target.MinPort = binary.LittleEndian.Uint16(port) + + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort) + target.MaxPort = binary.LittleEndian.Uint16(port) + return target, nil + } + + // Unknown target. + return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet", target.Name.String()) +} + // JumpTarget implements stack.Target. type JumpTarget struct { // Offset is the byte offset of the rule to jump to. It is used for @@ -30,6 +277,6 @@ type JumpTarget struct { } // Action implements stack.Target.Action. -func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrackTable, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 57a1e1c12..0bfd6c1f4 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -96,8 +96,8 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) +func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { + netHeader := header.IPv4(pkt.NetworkHeader().View()) if netHeader.TransportProtocol() != header.TCPProtocolNumber { return false, false @@ -111,36 +111,10 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa return false, false } - // Now we need the transport header. However, this may not have been set - // yet. - // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the stack.Check codepath as matchers are - // added. - var tcpHeader header.TCP - if pkt.TransportHeader != nil { - tcpHeader = header.TCP(pkt.TransportHeader) - } else { - var length int - if hook == stack.Prerouting { - // The network header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // There's no valid TCP header here, so we hotdrop the - // packet. - return false, true - } - h := header.IPv4(hdr) - pkt.NetworkHeader = hdr - length = int(h.HeaderLength()) - } - // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(length + header.TCPMinimumSize) - if !ok { - // There's no valid TCP header here, so we hotdrop the - // packet. - return false, true - } - tcpHeader = header.TCP(hdr[length:]) + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { + // There's no valid TCP header here, so we drop the packet immediately. + return false, true } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index cfa9e621d..7ed05461d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -93,8 +93,8 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) +func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { + netHeader := header.IPv4(pkt.NetworkHeader().View()) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. @@ -110,36 +110,10 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa return false, false } - // Now we need the transport header. However, this may not have been set - // yet. - // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the stack.Check codepath as matchers are - // added. - var udpHeader header.UDP - if pkt.TransportHeader != nil { - udpHeader = header.UDP(pkt.TransportHeader) - } else { - var length int - if hook == stack.Prerouting { - // The network header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // There's no valid UDP header here, so we hotdrop the - // packet. - return false, true - } - h := header.IPv4(hdr) - pkt.NetworkHeader = hdr - length = int(h.HeaderLength()) - } - // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(length + header.UDPMinimumSize) - if !ok { - // There's no valid UDP header here, so we hotdrop the - // packet. - return false, true - } - udpHeader = header.UDP(hdr[length:]) + udpHeader := header.UDP(pkt.TransportHeader().View()) + if len(udpHeader) < header.UDPMinimumSize { + // There's no valid UDP header here, so we drop the packet immediately. + return false, true } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 7212d8644..0546801bf 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", @@ -35,6 +36,8 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 0d45e5053..31e374833 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -97,7 +97,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int } d := socket.NewDirent(t, netlinkSocketDevice) - defer d.DecRef() + defer d.DecRef(t) return fs.NewFile(t, d, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, s), nil } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 81f34c5a2..68a9b9a96 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -38,6 +38,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const sizeOfInt32 int = 4 @@ -138,14 +140,14 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke // Bind the endpoint for good measure so we can connect to it. The // bound address will never be exposed. if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil { - ep.Close() + ep.Close(t) return nil, err } // Create a connection from which the kernel can write messages. connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t) if err != nil { - ep.Close() + ep.Close(t) return nil, err } @@ -162,9 +164,9 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { - s.connection.Release() - s.ep.Close() +func (s *socketOpsCommon) Release(ctx context.Context) { + s.connection.Release(ctx) + s.ep.Close(ctx) if s.bound { s.ports.Release(s.protocol.Protocol(), s.portID) @@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - return int32(s.sendBufferSize), nil + sendBufferSizeP := primitive.Int32(s.sendBufferSize) + return &sendBufferSizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - return int32(math.MaxInt32), nil + recvBufferSizeP := primitive.Int32(math.MaxInt32) + return &recvBufferSizeP, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var passcred int32 + var passcred primitive.Int32 if s.Passcred() { passcred = 1 } - return passcred, nil + return &passcred, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -617,7 +621,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys if len(bufs) > 0 { // RecvMsg never receives the address, so we don't need to send // one. - _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(ctx, bufs, cms, tcpip.FullAddress{}) // If the buffer is full, we simply drop messages, just like // Linux. if err != nil && err != syserr.ErrWouldBlock { @@ -644,7 +648,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys // Add the dump_done_errno payload. m.Put(int64(0)) - _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err } diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index b854bf990..a38d25da9 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -40,6 +41,7 @@ type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD socketOpsCommon } @@ -55,18 +57,18 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV // Bind the endpoint for good measure so we can connect to it. The // bound address will never be exposed. if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil { - ep.Close() + ep.Close(t) return nil, err } // Create a connection from which the kernel can write messages. connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t) if err != nil { - ep.Close() + ep.Close(t) return nil, err } - return &SocketVFS2{ + fd := &SocketVFS2{ socketOpsCommon: socketOpsCommon{ ports: t.Kernel().NetlinkPorts(), protocol: protocol, @@ -75,7 +77,9 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV connection: connection, sendBufferSize: defaultSendBufferSize, }, - }, nil + } + fd.LockFD.Init(&vfs.FileLocks{}) + return fd, nil } // Readiness implements waiter.Waitable.Readiness. @@ -136,3 +140,13 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) return int64(n), err.ToError() } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 333e0042e..1fb777a6c 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -28,6 +28,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", @@ -50,5 +51,8 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 60df51dae..e4846bc0b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "fmt" "io" "math" "reflect" @@ -33,6 +34,7 @@ import ( "syscall" "time" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/binary" @@ -60,6 +62,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { @@ -190,6 +194,8 @@ var Metrics = tcpip.Stats{ MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."), PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), + ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), + InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."), }, } @@ -294,8 +300,9 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages - sender tcpip.FullAddress + readCM tcpip.ControlMessages + sender tcpip.FullAddress + linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -323,7 +330,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue } dirent := socket.NewDirent(t, netstackDevice) - defer dirent.DecRef() + defer dirent.DecRef(t) return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{ socketOpsCommon: socketOpsCommon{ Queue: queue, @@ -416,7 +423,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), @@ -444,8 +451,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = nil s.sender = tcpip.FullAddress{} + s.linkPacketInfo = tcpip.LinkPacketInfo{} - v, cms, err := s.Endpoint.Read(&s.sender) + var v buffer.View + var cms tcpip.ControlMessages + var err *tcpip.Error + + switch e := s.Endpoint.(type) { + // The ordering of these interfaces matters. The most specific + // interfaces must be specified before the more generic Endpoint + // interface. + case tcpip.PacketEndpoint: + v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) + case tcpip.Endpoint: + v, cms, err = e.Read(&s.sender) + } if err != nil { atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) @@ -459,7 +479,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { +func (s *socketOpsCommon) Release(context.Context) { s.Endpoint.Close() } @@ -719,6 +739,14 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool defer s.EventUnregister(&e) if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting { + if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM { + // TCP unlike UDP returns EADDRNOTAVAIL when it can't + // find an available local ephemeral port. + if err == tcpip.ErrNoPortAvailable { + return syserr.ErrAddressNotAvailable + } + } + return syserr.TranslateNetstackError(err) } @@ -826,7 +854,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() @@ -884,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -894,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -930,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -945,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -955,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -988,7 +1016,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -999,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // Get the last error and convert it. err := ep.GetSockOpt(tcpip.ErrorOption{}) if err == nil { - return int32(0), nil + optP := primitive.Int32(0) + return &optP, nil } - return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil + + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + return &optP, nil case linux.SO_PEERCRED: if family != linux.AF_UNIX || outLen < syscall.SizeofUcred { @@ -1009,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } tcred := t.Credentials() - return syscall.Ucred{ - Pid: int32(t.ThreadGroup().ID()), - Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - }, nil + creds := linux.ControlMessageCredentials{ + PID: int32(t.ThreadGroup().ID()), + UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), + GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), + } + return &creds, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -1024,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1040,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { @@ -1056,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_REUSEADDR: if outLen < sizeOfInt32 { @@ -1067,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { @@ -1078,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1086,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } if v == 0 { - return []byte{}, nil + var b primitive.ByteSlice + return &b, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument @@ -1101,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // interface was removed. return nil, syserr.ErrUnknownDevice } - return append([]byte(nic.Name), 0), nil + + name := primitive.ByteSlice(append([]byte(nic.Name), 0)) + return &name, nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1112,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { @@ -1123,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - return linux.Linger{}, nil + + linger := linux.Linger{} + return &linger, nil case linux.SO_SNDTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1137,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.SendTimeout()), nil + sendTimeout := linux.NsecToTimeval(s.SendTimeout()) + return &sendTimeout, nil case linux.SO_RCVTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1145,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.RecvTimeout()), nil + recvTimeout := linux.NsecToTimeval(s.RecvTimeout()) + return &recvTimeout, nil case linux.SO_OOBINLINE: if outLen < sizeOfInt32 { @@ -1157,7 +1207,20 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil + + case linux.SO_NO_CHECK: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1166,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { @@ -1177,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(!v), nil + + vP := primitive.Int32(boolToInt32(!v)) + return &vP, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { @@ -1188,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { @@ -1199,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1210,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { @@ -1222,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveIdle, nil case linux.TCP_KEEPINTVL: if outLen < sizeOfInt32 { @@ -1234,8 +1303,20 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } + keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveInterval, nil - return int32(time.Duration(v) / time.Second), nil + case linux.TCP_KEEPCNT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.KeepaliveCountOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_USER_TIMEOUT: if outLen < sizeOfInt32 { @@ -1246,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Millisecond), nil + tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond) + return &tcpUserTimeout, nil case linux.TCP_INFO: var v tcpip.TCPInfoOption @@ -1260,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa info := linux.TCPInfo{} // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &info) - if len(ib) > outLen { - ib = ib[:outLen] + buf := t.CopyScratchBuffer(info.SizeBytes()) + info.MarshalUnsafe(buf) + if len(buf) > outLen { + buf = buf[:outLen] } - - return ib, nil + bufP := primitive.ByteSlice(buf) + return &bufP, nil case linux.TCP_CC_INFO, linux.TCP_NOTSENT_LOWAT, @@ -1295,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } b := make([]byte, toCopy) copy(b, v) - return b, nil + + bP := primitive.ByteSlice(b) + return &bP, nil case linux.TCP_LINGER2: if outLen < sizeOfInt32 { @@ -1307,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + return &lingerTimeout, nil case linux.TCP_DEFER_ACCEPT: if outLen < sizeOfInt32 { @@ -1319,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second) + return &tcpDeferAccept, nil case linux.TCP_SYNCNT: if outLen < sizeOfInt32 { @@ -1330,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_WINDOW_CLAMP: if outLen < sizeOfInt32 { @@ -1342,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil default: emitUnimplementedEventTCP(t, name) } @@ -1351,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { @@ -1362,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1370,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_TCLASS: // Length handling for parity with Linux. if outLen == 0 { - return make([]byte, 0), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - uintv := uint32(v) + uintv := primitive.Uint32(v) // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + ib := t.CopyScratchBuffer(uintv.SizeBytes()) + uintv.MarshalUnsafe(ib) // Handle cases where outLen is lesser than sizeOfInt32. if len(ib) > outLen { ib = ib[:outLen] } - return ib, nil + ibP := primitive.ByteSlice(ib) + return &ibP, nil case linux.IPV6_RECVTCLASS: if outLen < sizeOfInt32 { @@ -1395,7 +1486,13 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + + case linux.SO_ORIGINAL_DST: + // TODO(gvisor.dev/issue/170): ip6tables. + return nil, syserr.ErrInvalidArgument default: emitUnimplementedEventIPv6(t, name) @@ -1404,7 +1501,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1417,11 +1514,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } // Fill in the default value, if needed. - if v == 0 { - v = DefaultTTL + vP := primitive.Int32(v) + if vP == 0 { + vP = DefaultTTL } - return int32(v), nil + return &vP, nil case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { @@ -1433,7 +1531,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_MULTICAST_IF: if outLen < len(linux.InetAddr{}) { @@ -1447,7 +1546,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(*linux.SockAddrInet).Addr, nil + return &a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1458,21 +1557,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { - return []byte(nil), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { - return uint8(v), nil + vP := primitive.Uint8(v) + return &vP, nil } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_RECVTOS: if outLen < sizeOfInt32 { @@ -1483,7 +1587,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1494,7 +1600,22 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + + case linux.SO_ORIGINAL_DST: + if outLen < int(binary.Size(linux.SockAddrInet{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet), nil default: emitUnimplementedEventIP(t, name) @@ -1698,6 +1819,14 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v))) + case linux.SO_NO_CHECK: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { return syserr.ErrInvalidArgument @@ -1712,6 +1841,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return nil + case linux.SO_DETACH_FILTER: + // optval is ignored. + var v tcpip.SocketDetachFilterOption + return syserr.TranslateNetstackError(ep.SetSockOpt(v)) + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -1777,6 +1911,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_KEEPCNT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + if v < 1 || v > linux.MAX_TCP_KEEPCNT { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.KeepaliveCountOption, int(v))) + case linux.TCP_USER_TIMEOUT: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2060,13 +2205,22 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + case linux.IP_HDRINCL: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, linux.IP_CHECKSUM, linux.IP_DROP_SOURCE_MEMBERSHIP, linux.IP_FREEBIND, - linux.IP_HDRINCL, linux.IP_IPSEC_POLICY, linux.IP_MINTTL, linux.IP_MSFILTER, @@ -2106,30 +2260,20 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) { switch name { case linux.TCP_CONGESTION, linux.TCP_CORK, - linux.TCP_DEFER_ACCEPT, linux.TCP_FASTOPEN, linux.TCP_FASTOPEN_CONNECT, linux.TCP_FASTOPEN_KEY, linux.TCP_FASTOPEN_NO_COOKIE, - linux.TCP_KEEPCNT, - linux.TCP_KEEPIDLE, - linux.TCP_KEEPINTVL, - linux.TCP_LINGER2, - linux.TCP_MAXSEG, linux.TCP_QUEUE_SEQ, - linux.TCP_QUICKACK, linux.TCP_REPAIR, linux.TCP_REPAIR_QUEUE, linux.TCP_REPAIR_WINDOW, linux.TCP_SAVED_SYN, linux.TCP_SAVE_SYN, - linux.TCP_SYNCNT, linux.TCP_THIN_DUPACK, linux.TCP_THIN_LINEAR_TIMEOUTS, linux.TCP_TIMESTAMP, - linux.TCP_ULP, - linux.TCP_USER_TIMEOUT, - linux.TCP_WINDOW_CLAMP: + linux.TCP_ULP: t.Kernel().EmitUnimplementedEvent(t) } @@ -2291,7 +2435,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(sockAddrInet6Size) case linux.AF_PACKET: - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. var out linux.SockAddrLink out.Family = linux.AF_PACKET out.InterfaceIndex = int32(addr.NIC) @@ -2397,6 +2541,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func toLinuxPacketType(pktType tcpip.PacketType) uint8 { + switch pktType { + case tcpip.PacketHost: + return linux.PACKET_HOST + case tcpip.PacketOtherHost: + return linux.PACKET_OTHERHOST + case tcpip.PacketOutgoing: + return linux.PACKET_OUTGOING + case tcpip.PacketBroadcast: + return linux.PACKET_BROADCAST + case tcpip.PacketMulticast: + return linux.PACKET_MULTICAST + default: + panic(fmt.Sprintf("unknown packet type: %d", pktType)) + } +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -2452,6 +2613,11 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) + } } if peek { @@ -2686,11 +2852,16 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { - case syscall.SIOCGSTAMP: + case linux.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2698,9 +2869,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy } tv := linux.NsecToTimeval(s.timestampNS) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := tv.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2719,9 +2888,8 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err } @@ -2730,52 +2898,49 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // Ioctl performs a socket ioctl. func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + switch arg := int(args[1].Int()); arg { - case syscall.SIOCGIFFLAGS, - syscall.SIOCGIFADDR, - syscall.SIOCGIFBRDADDR, - syscall.SIOCGIFDSTADDR, - syscall.SIOCGIFHWADDR, - syscall.SIOCGIFINDEX, - syscall.SIOCGIFMAP, - syscall.SIOCGIFMETRIC, - syscall.SIOCGIFMTU, - syscall.SIOCGIFNAME, - syscall.SIOCGIFNETMASK, - syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFFLAGS, + linux.SIOCGIFADDR, + linux.SIOCGIFBRDADDR, + linux.SIOCGIFDSTADDR, + linux.SIOCGIFHWADDR, + linux.SIOCGIFINDEX, + linux.SIOCGIFMAP, + linux.SIOCGIFMETRIC, + linux.SIOCGIFMTU, + linux.SIOCGIFNAME, + linux.SIOCGIFNETMASK, + linux.SIOCGIFTXQLEN, + linux.SIOCETHTOOL: var ifr linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil { return 0, err.ToError() } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := ifr.CopyOut(t, args[2].Pointer()) return 0, err - case syscall.SIOCGIFCONF: + case linux.SIOCGIFCONF: // Return a list of interface addresses or the buffer size // necessary to hold the list. var ifc linux.IFConf - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } - if err := ifconfIoctl(ctx, io, &ifc); err != nil { + if err := ifconfIoctl(ctx, t, io, &ifc); err != nil { return 0, err } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }) - + _, err := ifc.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2788,9 +2953,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc v = math.MaxInt32 } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCOUTQ: @@ -2804,9 +2968,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG: @@ -2832,7 +2995,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to // identify a device. - if arg == syscall.SIOCGIFNAME { + if arg == linux.SIOCGIFNAME { // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4])) @@ -2855,21 +3018,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } switch arg { - case syscall.SIOCGIFINDEX: + case linux.SIOCGIFINDEX: // Copy out the index to the data. usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) - case syscall.SIOCGIFHWADDR: + case linux.SIOCGIFHWADDR: // Copy the hardware address out. - ifr.Data[0] = 6 // IEEE802.2 arp type. - ifr.Data[1] = 0 + // + // Refer: https://linux.die.net/man/7/netdevice + // SIOCGIFHWADDR, SIOCSIFHWADDR + // + // Get or set the hardware address of a device using + // ifr_hwaddr. The hardware address is specified in a struct + // sockaddr. sa_family contains the ARPHRD_* device type, + // sa_data the L2 hardware address starting from byte 0. Setting + // the hardware address is a privileged operation. + usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) n := copy(ifr.Data[2:], iface.Addr) for i := 2 + n; i < len(ifr.Data); i++ { ifr.Data[i] = 0 // Clear padding. } - usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n)) - case syscall.SIOCGIFFLAGS: + case linux.SIOCGIFFLAGS: f, err := interfaceStatusFlags(stack, iface.Name) if err != nil { return err @@ -2878,7 +3048,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // matches Linux behavior. usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) - case syscall.SIOCGIFADDR: + case linux.SIOCGIFADDR: // Copy the IPv4 address out. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2889,32 +3059,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } - case syscall.SIOCGIFMETRIC: + case linux.SIOCGIFMETRIC: // Gets the metric of the device. As per netdevice(7), this // always just sets ifr_metric to 0. usermem.ByteOrder.PutUint32(ifr.Data[:4], 0) - case syscall.SIOCGIFMTU: + case linux.SIOCGIFMTU: // Gets the MTU of the device. usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) - case syscall.SIOCGIFMAP: + case linux.SIOCGIFMAP: // Gets the hardware parameters of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFDSTADDR: + case linux.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFBRDADDR: + case linux.SIOCGIFBRDADDR: // Gets the broadcast address of a device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFNETMASK: + case linux.SIOCGIFNETMASK: // Gets the network mask of a device. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2931,6 +3101,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } + case linux.SIOCETHTOOL: + // Stubbed out for now, Ideally we should implement the required + // sub-commands for ETHTOOL + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c + return syserr.ErrEndpointOperation + default: // Not a valid call. return syserr.ErrInvalidArgument @@ -2940,7 +3118,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { +func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error { // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. @@ -2977,9 +3155,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Copy the ifr to userspace. dst := uintptr(ifc.Ptr) + uintptr(ifc.Len) ifc.Len += int32(linux.SizeOfIFReq) - if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil { return err } } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index fcd8013c0..3335e7430 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -30,6 +31,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // SocketVFS2 encapsulates all the state needed to represent a network stack @@ -38,6 +41,7 @@ type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD socketOpsCommon } @@ -64,6 +68,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu protocol: protocol, }, } + s.LockFD.Init(&vfs.FileLocks{}) vfsfd := &s.vfsfd if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, @@ -164,7 +169,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if err := ns.SetStatusFlags(t, t.Credentials(), uint32(flags&linux.SOCK_NONBLOCK)); err != nil { return 0, nil, 0, syserr.FromError(err) @@ -197,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketVFS2 rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -207,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -243,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -258,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -315,3 +320,13 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f5fa18136..f0fe18684 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -15,10 +15,11 @@ package netstack import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool { return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber) } +// Converts Netstack's ARPHardwareType to equivalent linux constants. +func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 { + switch t { + case header.ARPHardwareNone: + return linux.ARPHRD_NONE + case header.ARPHardwareLoopback: + return linux.ARPHRD_LOOPBACK + case header.ARPHardwareEther: + return linux.ARPHRD_ETHER + default: + panic(fmt.Sprintf("unknown ARPHRD type: %d", t)) + } +} + // Interfaces implements inet.Stack.Interfaces. func (s *Stack) Interfaces() map[int32]inet.Interface { is := make(map[int32]inet.Interface) for id, ni := range s.Stack.NICInfo() { - var devType uint16 - if ni.Flags.Loopback { - devType = linux.ARPHRD_LOOPBACK - } is[int32(id)] = inet.Interface{ Name: ni.Name, Addr: []byte(ni.LinkAddress), Flags: uint32(nicStateFlagsToLinux(ni.Flags)), - DeviceType: devType, + DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), MTU: ni.MTU, } } @@ -196,6 +207,20 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() } +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + var recovery tcp.Recovery + if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil { + return 0, syserr.TranslateNetstackError(err).ToError() + } + return inet.TCPLossRecovery(recovery), nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError() +} + // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { switch stats := stat.(type) { @@ -314,7 +339,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { udp.PacketsSent.Value(), // OutDatagrams. udp.ReceiveBufferErrors.Value(), // RcvbufErrors. 0, // Udp/SndbufErrors. - 0, // Udp/InCsumErrors. + udp.ChecksumErrors.Value(), // Udp/InCsumErrors. 0, // Udp/IgnoredMulti. } default: @@ -362,16 +387,10 @@ func (s *Stack) RouteTable() []inet.Route { } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() (stack.IPTables, error) { +func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillDefaultIPTables sets the stack's iptables to the default tables, which -// allow and do not modify all traffic. -func (s *Stack) FillDefaultIPTables() { - netfilter.FillDefaultIPTables(s.Stack) -} - // Resume implements inet.Stack.Resume. func (s *Stack) Resume() { s.Stack.Resume() diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 6580bd6e9..04b259d27 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // ControlMessages represents the union of unix control messages and tcpip @@ -45,8 +46,8 @@ type ControlMessages struct { } // Release releases Unix domain socket credentials and rights. -func (c *ControlMessages) Release() { - c.Unix.Release() +func (c *ControlMessages) Release(ctx context.Context) { + c.Unix.Release(ctx) } // Socket is an interface combining fs.FileOperations and SocketOps, @@ -86,7 +87,7 @@ type SocketOps interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux syscall. - GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux syscall. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error @@ -407,7 +408,6 @@ func emitUnimplementedEvent(t *kernel.Task, name int) { linux.SO_MARK, linux.SO_MAX_PACING_RATE, linux.SO_NOFCS, - linux.SO_NO_CHECK, linux.SO_OOBINLINE, linux.SO_PASSCRED, linux.SO_PASSSEC, diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index de2cc4bdf..061a689a9 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", @@ -34,5 +35,6 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index ce5b94ee7..c67b602f0 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -211,7 +211,7 @@ func (e *connectionedEndpoint) Listening() bool { // The socket will be a fresh state after a call to close and may be reused. // That is, close may be used to "unbind" or "disconnect" the socket in error // paths. -func (e *connectionedEndpoint) Close() { +func (e *connectionedEndpoint) Close(ctx context.Context) { e.Lock() var c ConnectedEndpoint var r Receiver @@ -233,7 +233,7 @@ func (e *connectionedEndpoint) Close() { case e.Listening(): close(e.acceptedChan) for n := range e.acceptedChan { - n.Close() + n.Close(ctx) } e.acceptedChan = nil e.path = "" @@ -241,18 +241,18 @@ func (e *connectionedEndpoint) Close() { e.Unlock() if c != nil { c.CloseNotify() - c.Release() + c.Release(ctx) } if r != nil { r.CloseNotify() - r.Release() + r.Release(ctx) } } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { if ce.Type() != e.stype { - return syserr.ErrConnectionRefused + return syserr.ErrWrongProtocolForSocket } // Check if ce is e to avoid a deadlock. @@ -340,7 +340,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn return nil default: // Busy; return ECONNREFUSED per spec. - ne.Close() + ne.Close(ctx) e.Unlock() ce.Unlock() return syserr.ErrConnectionRefused @@ -476,6 +476,9 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask // State implements socket.Socket.State. func (e *connectionedEndpoint) State() uint32 { + e.Lock() + defer e.Unlock() + if e.Connected() { return linux.SS_CONNECTED } diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 4b06d63ac..70ee8f9b8 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -54,10 +54,10 @@ func (e *connectionlessEndpoint) isBound() bool { // Close puts the endpoint in a closed state and frees all resources associated // with it. -func (e *connectionlessEndpoint) Close() { +func (e *connectionlessEndpoint) Close(ctx context.Context) { e.Lock() if e.connected != nil { - e.connected.Release() + e.connected.Release(ctx) e.connected = nil } @@ -71,7 +71,7 @@ func (e *connectionlessEndpoint) Close() { e.Unlock() r.CloseNotify() - r.Release() + r.Release(ctx) } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. @@ -108,10 +108,10 @@ func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c C if err != nil { return 0, syserr.ErrInvalidEndpointState } - defer connected.Release() + defer connected.Release(ctx) e.Lock() - n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + n, notify, err := connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() if notify { @@ -135,7 +135,7 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi e.Lock() if e.connected != nil { - e.connected.Release() + e.connected.Release(ctx) } e.connected = connected e.Unlock() diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index d8f3ad63d..ef6043e19 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -15,6 +15,7 @@ package transport import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" @@ -57,10 +58,10 @@ func (q *queue) Close() { // Both the read and write queues must be notified after resetting: // q.ReaderQueue.Notify(waiter.EventIn) // q.WriterQueue.Notify(waiter.EventOut) -func (q *queue) Reset() { +func (q *queue) Reset(ctx context.Context) { q.mu.Lock() for cur := q.dataList.Front(); cur != nil; cur = cur.Next() { - cur.Release() + cur.Release(ctx) } q.dataList.Reset() q.used = 0 @@ -68,8 +69,8 @@ func (q *queue) Reset() { } // DecRef implements RefCounter.DecRef with destructor q.Reset. -func (q *queue) DecRef() { - q.DecRefWithDestructor(q.Reset) +func (q *queue) DecRef(ctx context.Context) { + q.DecRefWithDestructor(ctx, q.Reset) // We don't need to notify after resetting because no one cares about // this queue after all references have been dropped. } @@ -111,7 +112,7 @@ func (q *queue) IsWritable() bool { // // If notify is true, ReaderQueue.Notify must be called: // q.ReaderQueue.Notify(waiter.EventIn) -func (q *queue) Enqueue(data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) { +func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) { q.mu.Lock() if q.closed { @@ -124,7 +125,7 @@ func (q *queue) Enqueue(data [][]byte, c ControlMessages, from tcpip.FullAddress } if discardEmpty && l == 0 { q.mu.Unlock() - c.Release() + c.Release(ctx) return 0, false, nil } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2f1b127df..475d7177e 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -37,7 +37,7 @@ type RightsControlMessage interface { Clone() RightsControlMessage // Release releases any resources owned by the RightsControlMessage. - Release() + Release(ctx context.Context) } // A CredentialsControlMessage is a control message containing Unix credentials. @@ -74,9 +74,9 @@ func (c *ControlMessages) Clone() ControlMessages { } // Release releases both the credentials and the rights. -func (c *ControlMessages) Release() { +func (c *ControlMessages) Release(ctx context.Context) { if c.Rights != nil { - c.Rights.Release() + c.Rights.Release(ctx) } *c = ControlMessages{} } @@ -90,7 +90,7 @@ type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources // associated with it. - Close() + Close(ctx context.Context) // RecvMsg reads data and a control message from the endpoint. This method // does not block if there is no data pending. @@ -252,7 +252,7 @@ type BoundEndpoint interface { // Release releases any resources held by the BoundEndpoint. It must be // called before dropping all references to a BoundEndpoint returned by a // function. - Release() + Release(ctx context.Context) } // message represents a message passed over a Unix domain socket. @@ -281,8 +281,8 @@ func (m *message) Length() int64 { } // Release releases any resources held by the message. -func (m *message) Release() { - m.Control.Release() +func (m *message) Release(ctx context.Context) { + m.Control.Release(ctx) } // Peek returns a copy of the message. @@ -304,7 +304,7 @@ type Receiver interface { // See Endpoint.RecvMsg for documentation on shared arguments. // // notify indicates if RecvNotify should be called. - Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) + Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) // RecvNotify notifies the Receiver of a successful Recv. This must not be // called while holding any endpoint locks. @@ -333,7 +333,7 @@ type Receiver interface { // Release releases any resources owned by the Receiver. It should be // called before droping all references to a Receiver. - Release() + Release(ctx context.Context) } // queueReceiver implements Receiver for datagram sockets. @@ -344,7 +344,7 @@ type queueReceiver struct { } // Recv implements Receiver.Recv. -func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *queueReceiver) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { var m *message var notify bool var err *syserr.Error @@ -398,8 +398,8 @@ func (q *queueReceiver) RecvMaxQueueSize() int64 { } // Release implements Receiver.Release. -func (q *queueReceiver) Release() { - q.readQueue.DecRef() +func (q *queueReceiver) Release(ctx context.Context) { + q.readQueue.DecRef(ctx) } // streamQueueReceiver implements Receiver for stream sockets. @@ -456,7 +456,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 { } // Recv implements Receiver.Recv. -func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { q.mu.Lock() defer q.mu.Unlock() @@ -502,7 +502,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, var cmTruncated bool if c.Rights != nil && numRights == 0 { - c.Rights.Release() + c.Rights.Release(ctx) c.Rights = nil cmTruncated = true } @@ -557,7 +557,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, // Consume rights. if numRights == 0 { cmTruncated = true - q.control.Rights.Release() + q.control.Rights.Release(ctx) } else { c.Rights = q.control.Rights haveRights = true @@ -582,7 +582,7 @@ type ConnectedEndpoint interface { // // syserr.ErrWouldBlock can be returned along with a partial write if // the caller should block to send the rest of the data. - Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) + Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) // SendNotify notifies the ConnectedEndpoint of a successful Send. This // must not be called while holding any endpoint locks. @@ -616,7 +616,7 @@ type ConnectedEndpoint interface { // Release releases any resources owned by the ConnectedEndpoint. It should // be called before droping all references to a ConnectedEndpoint. - Release() + Release(ctx context.Context) // CloseUnread sets the fact that this end is closed with unread data to // the peer socket. @@ -654,7 +654,7 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) } // Send implements ConnectedEndpoint.Send. -func (e *connectedEndpoint) Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { +func (e *connectedEndpoint) Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { discardEmpty := false truncate := false if e.endpoint.Type() == linux.SOCK_STREAM { @@ -669,7 +669,7 @@ func (e *connectedEndpoint) Send(data [][]byte, c ControlMessages, from tcpip.Fu truncate = true } - return e.writeQueue.Enqueue(data, c, from, discardEmpty, truncate) + return e.writeQueue.Enqueue(ctx, data, c, from, discardEmpty, truncate) } // SendNotify implements ConnectedEndpoint.SendNotify. @@ -707,8 +707,8 @@ func (e *connectedEndpoint) SendMaxQueueSize() int64 { } // Release implements ConnectedEndpoint.Release. -func (e *connectedEndpoint) Release() { - e.writeQueue.DecRef() +func (e *connectedEndpoint) Release(ctx context.Context) { + e.writeQueue.DecRef(ctx) } // CloseUnread implements ConnectedEndpoint.CloseUnread. @@ -798,7 +798,7 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, n return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected } - recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) + recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(ctx, data, creds, numRights, peek) e.Unlock() if err != nil { return 0, 0, ControlMessages{}, false, err @@ -827,7 +827,7 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess return 0, syserr.ErrAlreadyConnected } - n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + n, notify, err := e.connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() if notify { @@ -1001,6 +1001,6 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } // Release implements BoundEndpoint.Release. -func (*baseEndpoint) Release() { +func (*baseEndpoint) Release(context.Context) { // Binding a baseEndpoint doesn't take a reference. } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 5b29e9d7f..2b8454edb 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -40,6 +40,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -61,7 +62,7 @@ type SocketOperations struct { // New creates a new unix socket. func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true}) } @@ -96,17 +97,17 @@ type socketOpsCommon struct { } // DecRef implements RefCounter.DecRef. -func (s *socketOpsCommon) DecRef() { - s.DecRefWithDestructor(func() { - s.ep.Close() +func (s *socketOpsCommon) DecRef(ctx context.Context) { + s.DecRefWithDestructor(ctx, func(context.Context) { + s.ep.Close(ctx) }) } // Release implemements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { +func (s *socketOpsCommon) Release(ctx context.Context) { // Release only decrements a reference on s because s may be referenced in // the abstract socket namespace. - s.DecRef() + s.DecRef(ctx) } func (s *socketOpsCommon) isPacket() bool { @@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } @@ -233,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } ns := New(t, ep, s.stype) - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() @@ -283,7 +284,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if t.IsNetworkNamespaced() { return syserr.ErrInvalidEndpointState } - if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + if err := t.AbstractSockets().Bind(t, p[1:], bep, s); err != nil { // syserr.ErrPortInUse corresponds to EADDRINUSE. return syserr.ErrPortInUse } @@ -293,7 +294,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { var name string cwd := t.FSContext().WorkingDirectory() - defer cwd.DecRef() + defer cwd.DecRef(t) // Is there no slash at all? if !strings.Contains(p, "/") { @@ -301,7 +302,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { name = p } else { root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) // Find the last path component, we know that something follows // that final slash, otherwise extractPath() would have failed. lastSlash := strings.LastIndex(p, "/") @@ -317,7 +318,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // No path available. return syserr.ErrNoSuchFile } - defer d.DecRef() + defer d.DecRef(t) name = p[lastSlash+1:] } @@ -331,7 +332,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if err != nil { return syserr.ErrPortInUse } - childDir.DecRef() + childDir.DecRef(t) } return nil @@ -377,9 +378,9 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, FollowFinalSymlink: true, } ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path}) - root.DecRef() + root.DecRef(t) if relPath { - start.DecRef() + start.DecRef(t) } if e != nil { return nil, syserr.FromError(e) @@ -392,15 +393,15 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, cwd := t.FSContext().WorkingDirectory() remainingTraversals := uint(fs.DefaultTraversalLimit) d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals) - cwd.DecRef() - root.DecRef() + cwd.DecRef(t) + root.DecRef(t) if e != nil { return nil, syserr.FromError(e) } // Extract the endpoint if one is there. ep := d.Inode.BoundEndpoint(path) - d.DecRef() + d.DecRef(t) if ep == nil { // No socket! return nil, syserr.ErrConnectionRefused @@ -414,10 +415,21 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool if err != nil { return err } - defer ep.Release() + defer ep.Release(t) // Connect the server endpoint. - return s.ep.Connect(t, ep) + err = s.ep.Connect(t, ep) + + if err == syserr.ErrWrongProtocolForSocket { + // Linux for abstract sockets returns ErrConnectionRefused + // instead of ErrWrongProtocolForSocket. + path, _ := extractPath(sockaddr) + if len(path) > 0 && path[0] == 0 { + err = syserr.ErrConnectionRefused + } + } + + return err } // Write implements fs.FileOperations.Write. @@ -448,15 +460,25 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b To: nil, } if len(to) > 0 { - ep, err := extractEndpoint(t, to) - if err != nil { - return 0, err - } - defer ep.Release() - w.To = ep + switch s.stype { + case linux.SOCK_SEQPACKET: + to = nil + case linux.SOCK_STREAM: + if s.State() == linux.SS_CONNECTED { + return 0, syserr.ErrAlreadyConnected + } + return 0, syserr.ErrNotSupported + default: + ep, err := extractEndpoint(t, to) + if err != nil { + return 0, err + } + defer ep.Release(t) + w.To = ep - if ep.Passcred() && w.Control.Credentials == nil { - w.Control.Credentials = control.MakeCreds(t) + if ep.Passcred() && w.Control.Credentials == nil { + w.Control.Credentials = control.MakeCreds(t) + } } } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 45e109361..dfa25241a 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -31,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketVFS2 implements socket.SocketVFS2 (and by extension, @@ -39,6 +41,7 @@ type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD socketOpsCommon } @@ -51,7 +54,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) - fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d) + fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) if err != nil { return nil, syserr.FromError(err) } @@ -60,7 +63,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) // NewFileDescription creates and returns a socket file description // corresponding to the given mount and dentry. -func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry) (*vfs.FileDescription, error) { +func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) { // You can create AF_UNIX, SOCK_RAW sockets. They're the same as // SOCK_DGRAM and don't require CAP_NET_RAW. if stype == linux.SOCK_RAW { @@ -73,6 +76,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 stype: stype, }, } + sock.LockFD.Init(locks) vfsfd := &sock.vfsfd if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, @@ -86,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } @@ -132,7 +136,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK) @@ -179,19 +183,19 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if t.IsNetworkNamespaced() { return syserr.ErrInvalidEndpointState } - if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + if err := t.AbstractSockets().Bind(t, p[1:], bep, s); err != nil { // syserr.ErrPortInUse corresponds to EADDRINUSE. return syserr.ErrPortInUse } } else { path := fspath.Parse(p) root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) start := root relPath := !path.Absolute if relPath { start = t.FSContext().WorkingDirectoryVFS2() - defer start.DecRef() + defer start.DecRef(t) } pop := vfs.PathOperation{ Root: root, @@ -297,6 +301,16 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) } +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} + // providerVFS2 is a unix domain socket provider for VFS2. type providerVFS2 struct{} @@ -319,7 +333,7 @@ func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) f, err := NewSockfsFile(t, ep, stype) if err != nil { - ep.Close() + ep.Close(t) return nil, err } return f, nil @@ -343,14 +357,14 @@ func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (* ep1, ep2 := transport.NewPair(t, stype, t.Kernel()) s1, err := NewSockfsFile(t, ep1, stype) if err != nil { - ep1.Close() - ep2.Close() + ep1.Close(t) + ep2.Close(t) return nil, nil, err } s2, err := NewSockfsFile(t, ep2, stype) if err != nil { - s1.DecRef() - ep2.Close() + s1.DecRef(t) + ep2.Close(t) return nil, nil, err } |