diff options
-rw-r--r-- | pkg/sentry/fs/context.go | 2 | ||||
-rw-r--r-- | pkg/sentry/inet/BUILD | 2 | ||||
-rw-r--r-- | pkg/sentry/inet/context.go | 35 | ||||
-rw-r--r-- | pkg/sentry/kernel/task.go | 3 | ||||
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 25 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/route/protocol.go | 21 |
6 files changed, 66 insertions, 22 deletions
diff --git a/pkg/sentry/fs/context.go b/pkg/sentry/fs/context.go index b521bce75..da46ad77f 100644 --- a/pkg/sentry/fs/context.go +++ b/pkg/sentry/fs/context.go @@ -20,7 +20,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" ) -// contextID is the kernel package's type for context.Context.Value keys. +// contextID is the fs package's type for context.Context.Value keys. type contextID int const ( diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 207cdb692..1150ced57 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -17,12 +17,14 @@ go_stateify( go_library( name = "inet", srcs = [ + "context.go", "inet.go", "inet_state.go", "test_stack.go", ], importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/inet", deps = [ + "//pkg/sentry/context", "//pkg/state", ], ) diff --git a/pkg/sentry/inet/context.go b/pkg/sentry/inet/context.go new file mode 100644 index 000000000..370381f41 --- /dev/null +++ b/pkg/sentry/inet/context.go @@ -0,0 +1,35 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inet + +import ( + "gvisor.googlesource.com/gvisor/pkg/sentry/context" +) + +// contextID is the inet package's type for context.Context.Value keys. +type contextID int + +const ( + // CtxStack is a Context.Value key for a network stack. + CtxStack contextID = iota +) + +// StackFromContext returns the network stack associated with ctx. +func StackFromContext(ctx context.Context) Stack { + if v := ctx.Value(CtxStack); v != nil { + return v.(Stack) + } + return nil +} diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 3d2e035e9..490f795c2 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -22,6 +22,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/bpf" "gvisor.googlesource.com/gvisor/pkg/sentry/arch" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" + "gvisor.googlesource.com/gvisor/pkg/sentry/inet" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/futex" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/sched" @@ -560,6 +561,8 @@ func (t *Task) Value(key interface{}) interface{} { return t.creds case fs.CtxRoot: return t.FSContext().RootDirectory() + case inet.CtxStack: + return t.NetworkContext() case ktime.CtxRealtimeClock: return t.k.RealtimeClock() case limits.CtxLimits: diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index a45dcd551..3e4887e16 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -33,6 +33,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/binary" + "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/sentry/arch" "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" @@ -102,7 +103,6 @@ type SocketOperations struct { *waiter.Queue family int - stack inet.Stack Endpoint tcpip.Endpoint skType unix.SockType @@ -119,7 +119,6 @@ func New(t *kernel.Task, family int, skType unix.SockType, queue *waiter.Queue, return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true}, &SocketOperations{ Queue: queue, family: family, - stack: t.NetworkContext(), Endpoint: endpoint, skType: skType, }) @@ -1042,7 +1041,12 @@ func (s *SocketOperations) interfaceIoctl(ctx context.Context, io usermem.IO, ar ) // Find the relevant device. - for index, iface = range s.stack.Interfaces() { + stack := inet.StackFromContext(ctx) + if stack == nil { + log.Warningf("Couldn't find a network stack.") + return syserr.ErrInvalidArgument + } + for index, iface = range stack.Interfaces() { if iface.Name == ifr.Name() { found = true break @@ -1074,7 +1078,7 @@ func (s *SocketOperations) interfaceIoctl(ctx context.Context, io usermem.IO, ar case syscall.SIOCGIFADDR: // Copy the IPv4 address out. - for _, addr := range s.stack.InterfaceAddrs()[index] { + for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. if addr.Family != linux.AF_INET { continue @@ -1109,7 +1113,7 @@ func (s *SocketOperations) interfaceIoctl(ctx context.Context, io usermem.IO, ar case syscall.SIOCGIFNETMASK: // Gets the network mask of a device. - for _, addr := range s.stack.InterfaceAddrs()[index] { + for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. if addr.Family != linux.AF_INET { continue @@ -1189,15 +1193,20 @@ func (s *SocketOperations) ifconfIoctl(ctx context.Context, io usermem.IO, ifc * // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. + stack := inet.StackFromContext(ctx) + if stack == nil { + log.Warningf("Couldn't find a network stack.") + return syserr.ErrInvalidArgument.ToError() + } if ifc.Ptr == 0 { - ifc.Len = int32(len(s.stack.Interfaces())) * int32(linux.SizeOfIFReq) + ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq) return nil } max := ifc.Len ifc.Len = 0 - for key, ifaceAddrs := range s.stack.InterfaceAddrs() { - iface := s.stack.Interfaces()[key] + for key, ifaceAddrs := range stack.InterfaceAddrs() { + iface := stack.Interfaces()[key] for _, ifaceAddr := range ifaceAddrs { // Don't write past the end of the buffer. if ifc.Len+int32(linux.SizeOfIFReq) > max { diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index d611519d4..e8030c518 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -43,20 +43,13 @@ func typeKind(typ uint16) commandKind { } // Protocol implements netlink.Protocol. -type Protocol struct { - // stack is the network stack that this provider describes. - // - // May be nil. - stack inet.Stack -} +type Protocol struct{} var _ netlink.Protocol = (*Protocol)(nil) // NewProtocol creates a NETLINK_ROUTE netlink.Protocol. func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) { - return &Protocol{ - stack: t.NetworkContext(), - }, nil + return &Protocol{}, nil } // Protocol implements netlink.Protocol.Protocol. @@ -83,12 +76,13 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader // We always send back an NLMSG_DONE. ms.Multi = true - if p.stack == nil { + stack := inet.StackFromContext(ctx) + if stack == nil { // No network devices. return nil } - for id, i := range p.stack.Interfaces() { + for id, i := range stack.Interfaces() { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.RTM_NEWLINK, }) @@ -124,12 +118,13 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader // We always send back an NLMSG_DONE. ms.Multi = true - if p.stack == nil { + stack := inet.StackFromContext(ctx) + if stack == nil { // No network devices. return nil } - for id, as := range p.stack.InterfaceAddrs() { + for id, as := range stack.InterfaceAddrs() { for _, a := range as { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.RTM_NEWADDR, |