summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go25
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go21
2 files changed, 25 insertions, 21 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index a45dcd551..3e4887e16 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -33,6 +33,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/log"
"gvisor.googlesource.com/gvisor/pkg/sentry/arch"
"gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
@@ -102,7 +103,6 @@ type SocketOperations struct {
*waiter.Queue
family int
- stack inet.Stack
Endpoint tcpip.Endpoint
skType unix.SockType
@@ -119,7 +119,6 @@ func New(t *kernel.Task, family int, skType unix.SockType, queue *waiter.Queue,
return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true}, &SocketOperations{
Queue: queue,
family: family,
- stack: t.NetworkContext(),
Endpoint: endpoint,
skType: skType,
})
@@ -1042,7 +1041,12 @@ func (s *SocketOperations) interfaceIoctl(ctx context.Context, io usermem.IO, ar
)
// Find the relevant device.
- for index, iface = range s.stack.Interfaces() {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ log.Warningf("Couldn't find a network stack.")
+ return syserr.ErrInvalidArgument
+ }
+ for index, iface = range stack.Interfaces() {
if iface.Name == ifr.Name() {
found = true
break
@@ -1074,7 +1078,7 @@ func (s *SocketOperations) interfaceIoctl(ctx context.Context, io usermem.IO, ar
case syscall.SIOCGIFADDR:
// Copy the IPv4 address out.
- for _, addr := range s.stack.InterfaceAddrs()[index] {
+ for _, addr := range stack.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
if addr.Family != linux.AF_INET {
continue
@@ -1109,7 +1113,7 @@ func (s *SocketOperations) interfaceIoctl(ctx context.Context, io usermem.IO, ar
case syscall.SIOCGIFNETMASK:
// Gets the network mask of a device.
- for _, addr := range s.stack.InterfaceAddrs()[index] {
+ for _, addr := range stack.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
if addr.Family != linux.AF_INET {
continue
@@ -1189,15 +1193,20 @@ func (s *SocketOperations) ifconfIoctl(ctx context.Context, io usermem.IO, ifc *
// If Ptr is NULL, return the necessary buffer size via Len.
// Otherwise, write up to Len bytes starting at Ptr containing ifreq
// structs.
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ log.Warningf("Couldn't find a network stack.")
+ return syserr.ErrInvalidArgument.ToError()
+ }
if ifc.Ptr == 0 {
- ifc.Len = int32(len(s.stack.Interfaces())) * int32(linux.SizeOfIFReq)
+ ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq)
return nil
}
max := ifc.Len
ifc.Len = 0
- for key, ifaceAddrs := range s.stack.InterfaceAddrs() {
- iface := s.stack.Interfaces()[key]
+ for key, ifaceAddrs := range stack.InterfaceAddrs() {
+ iface := stack.Interfaces()[key]
for _, ifaceAddr := range ifaceAddrs {
// Don't write past the end of the buffer.
if ifc.Len+int32(linux.SizeOfIFReq) > max {
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index d611519d4..e8030c518 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -43,20 +43,13 @@ func typeKind(typ uint16) commandKind {
}
// Protocol implements netlink.Protocol.
-type Protocol struct {
- // stack is the network stack that this provider describes.
- //
- // May be nil.
- stack inet.Stack
-}
+type Protocol struct{}
var _ netlink.Protocol = (*Protocol)(nil)
// NewProtocol creates a NETLINK_ROUTE netlink.Protocol.
func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
- return &Protocol{
- stack: t.NetworkContext(),
- }, nil
+ return &Protocol{}, nil
}
// Protocol implements netlink.Protocol.Protocol.
@@ -83,12 +76,13 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader
// We always send back an NLMSG_DONE.
ms.Multi = true
- if p.stack == nil {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
// No network devices.
return nil
}
- for id, i := range p.stack.Interfaces() {
+ for id, i := range stack.Interfaces() {
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.RTM_NEWLINK,
})
@@ -124,12 +118,13 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader
// We always send back an NLMSG_DONE.
ms.Multi = true
- if p.stack == nil {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
// No network devices.
return nil
}
- for id, as := range p.stack.InterfaceAddrs() {
+ for id, as := range stack.InterfaceAddrs() {
for _, a := range as {
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.RTM_NEWADDR,