summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/epsocket/BUILD2
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go136
-rw-r--r--pkg/sentry/socket/epsocket/stack.go57
-rw-r--r--pkg/sentry/socket/hostinet/socket.go37
-rw-r--r--pkg/sentry/socket/hostinet/socket_unsafe.go10
-rw-r--r--pkg/sentry/socket/hostinet/stack.go93
-rw-r--r--pkg/sentry/socket/netfilter/BUILD24
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go286
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go62
-rw-r--r--pkg/sentry/socket/netlink/socket.go45
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/BUILD3
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/notifier.go3
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go23
-rw-r--r--pkg/sentry/socket/rpcinet/stack.go31
-rw-r--r--pkg/sentry/socket/socket.go40
-rw-r--r--pkg/sentry/socket/unix/io.go4
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go2
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go2
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go36
-rw-r--r--pkg/sentry/socket/unix/unix.go20
21 files changed, 780 insertions, 137 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 2b03ea87c..3300f9a6b 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -9,6 +9,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/binary",
"//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD
index 1f014f399..e927821e1 100644
--- a/pkg/sentry/socket/epsocket/BUILD
+++ b/pkg/sentry/socket/epsocket/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
+ "//pkg/sentry/socket/netfilter",
"//pkg/sentry/unimpl",
"//pkg/sentry/usermem",
"//pkg/syserr",
@@ -38,6 +39,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index e57aed927..635042263 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -43,6 +43,7 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
@@ -290,18 +291,22 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
-// GetAddress reads an sockaddr struct from the given address and converts it
-// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
-// addresses.
-func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) {
+// AddressAndFamily reads an sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and
+// AF_INET6 addresses.
+//
+// strict indicates whether addresses with the AF_UNSPEC family are accepted of not.
+//
+// AddressAndFamily returns an address, its family.
+func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
}
family := usermem.ByteOrder.Uint16(addr)
if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
- return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
}
// Get the rest of the fields based on the address family.
@@ -309,7 +314,7 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
case linux.AF_UNIX:
path := addr[2:]
if len(path) > linux.UnixPathMax {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
// Drop the terminating NUL (if one exists) and everything after
// it for filesystem (non-abstract) addresses.
@@ -320,12 +325,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
}
return tcpip.FullAddress{
Addr: tcpip.Address(path),
- }, nil
+ }, family, nil
case linux.AF_INET:
var a linux.SockAddrInet
if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
@@ -333,12 +338,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
Addr: bytesToIPAddress(a.Addr[:]),
Port: ntohs(a.Port),
}
- return out, nil
+ return out, family, nil
case linux.AF_INET6:
var a linux.SockAddrInet6
if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
@@ -349,13 +354,13 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
if isLinkLocal(out.Addr) {
out.NIC = tcpip.NICID(a.Scope_id)
}
- return out, nil
+ return out, family, nil
case linux.AF_UNSPEC:
- return tcpip.FullAddress{}, nil
+ return tcpip.FullAddress{}, family, nil
default:
- return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
}
}
@@ -428,6 +433,11 @@ func (i *ioSequencePayload) Size() int {
return int(i.src.NumBytes())
}
+// DropFirst drops the first n bytes from underlying src.
+func (i *ioSequencePayload) DropFirst(n int) {
+ i.src = i.src.DropFirst(int(n))
+}
+
// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
f := &ioSequencePayload{ctx: ctx, src: src}
@@ -476,11 +486,18 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr, false /* strict */)
+ addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */)
if err != nil {
return err
}
+ if family == linux.AF_UNSPEC {
+ err := s.Endpoint.Disconnect()
+ if err == tcpip.ErrNotSupported {
+ return syserr.ErrAddressFamilyNotSupported
+ }
+ return syserr.TranslateNetstackError(err)
+ }
// Always return right away in the non-blocking case.
if !blocking {
return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
@@ -509,7 +526,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr, true /* strict */)
+ addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */)
if err != nil {
return err
}
@@ -547,7 +564,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait
// Accept implements the linux syscall accept(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
ep, wq, terr := s.Endpoint.Accept()
if terr != nil {
@@ -574,7 +591,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
ns.SetFlags(flags.Settable())
}
- var addr interface{}
+ var addr linux.SockAddr
var addrLen uint32
if peerRequested {
// Get address of the peer and write it to peer slice.
@@ -624,7 +641,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for epsocket.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -655,6 +672,33 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (
return val, nil
}
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_GET_INFO:
+ if outLen < linux.SizeOfIPTGetinfo {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ info, err := netfilter.GetInfo(t, s.Endpoint, outPtr)
+ if err != nil {
+ return nil, err
+ }
+ return info, nil
+
+ case linux.IPT_SO_GET_ENTRIES:
+ if outLen < linux.SizeOfIPTGetEntries {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ entries, err := netfilter.GetEntries(t, s.Endpoint, outPtr, outLen)
+ if err != nil {
+ return nil, err
+ }
+ return entries, nil
+
+ }
+ }
+
return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
}
@@ -1028,7 +1072,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac
a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
- return a.(linux.SockAddrInet).Addr, nil
+ return a.(*linux.SockAddrInet).Addr, nil
case linux.IP_MULTICAST_LOOP:
if outLen < sizeOfInt32 {
@@ -1658,7 +1702,7 @@ func isLinkLocal(addr tcpip.Address) bool {
}
// ConvertAddress converts the given address to a native format.
-func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) {
+func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
switch family {
case linux.AF_UNIX:
var out linux.SockAddrUnix
@@ -1674,15 +1718,15 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) {
// address length is the max. Abstract and empty paths always return
// the full exact length.
if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
- return out, uint32(2 + l)
+ return &out, uint32(2 + l)
}
- return out, uint32(3 + l)
+ return &out, uint32(3 + l)
case linux.AF_INET:
var out linux.SockAddrInet
copy(out.Addr[:], addr.Addr)
out.Family = linux.AF_INET
out.Port = htons(addr.Port)
- return out, uint32(binary.Size(out))
+ return &out, uint32(binary.Size(out))
case linux.AF_INET6:
var out linux.SockAddrInet6
if len(addr.Addr) == 4 {
@@ -1698,7 +1742,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) {
if isLinkLocal(addr.Addr) {
out.Scope_id = uint32(addr.NIC)
}
- return out, uint32(binary.Size(out))
+ return &out, uint32(binary.Size(out))
default:
return nil, 0
}
@@ -1706,7 +1750,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) {
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -1718,7 +1762,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -1791,7 +1835,7 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) {
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
-func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
isPacket := s.isPacketBased()
// Fast path for regular reads from stream (e.g., TCP) endpoints. Note
@@ -1839,7 +1883,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
if err == nil {
s.updateTimestamp()
}
- var addr interface{}
+ var addr linux.SockAddr
var addrLen uint32
if isPacket && senderRequested {
addr, addrLen = ConvertAddress(s.family, s.sender)
@@ -1914,7 +1958,7 @@ func (s *SocketOperations) updateTimestamp() {
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -1990,7 +2034,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, err := GetAddress(s.family, to, true /* strict */)
+ addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */)
if err != nil {
return 0, err
}
@@ -1998,28 +2042,22 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
addr = &addrBuf
}
- v := buffer.NewView(int(src.NumBytes()))
-
- // Copy all the data into the buffer.
- if _, err := src.CopyIn(t, v); err != nil {
- return 0, syserr.FromError(err)
- }
-
opts := tcpip.WriteOptions{
To: addr,
More: flags&linux.MSG_MORE != 0,
EndOfRecord: flags&linux.MSG_EOR != 0,
}
- n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ v := &ioSequencePayload{t, src}
+ n, resCh, err := s.Endpoint.Write(v, opts)
if resCh != nil {
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err)
}
- n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ n, _, err = s.Endpoint.Write(v, opts)
}
dontWait := flags&linux.MSG_DONTWAIT != 0
- if err == nil && (n >= uintptr(len(v)) || dontWait) {
+ if err == nil && (n >= int64(v.Size()) || dontWait) {
// Complete write.
return int(n), nil
}
@@ -2033,18 +2071,18 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
s.EventRegister(&e, waiter.EventOut)
defer s.EventUnregister(&e)
- v.TrimFront(int(n))
+ v.DropFirst(int(n))
total := n
for {
- n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
- v.TrimFront(int(n))
+ n, _, err = s.Endpoint.Write(v, opts)
+ v.DropFirst(int(n))
total += n
if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
return 0, syserr.TranslateNetstackError(err)
}
- if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
return int(total), nil
}
@@ -2252,19 +2290,19 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case syscall.SIOCGIFMAP:
// Gets the hardware parameters of the device.
- // TODO(b/71872867): Implement.
+ // TODO(gvisor.dev/issue/505): Implement.
case syscall.SIOCGIFTXQLEN:
// Gets the transmit queue length of the device.
- // TODO(b/71872867): Implement.
+ // TODO(gvisor.dev/issue/505): Implement.
case syscall.SIOCGIFDSTADDR:
// Gets the destination address of a point-to-point device.
- // TODO(b/71872867): Implement.
+ // TODO(gvisor.dev/issue/505): Implement.
case syscall.SIOCGIFBRDADDR:
// Gets the broadcast address of a device.
- // TODO(b/71872867): Implement.
+ // TODO(gvisor.dev/issue/505): Implement.
case syscall.SIOCGIFNETMASK:
// Gets the network mask of a device.
diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go
index 8fe489c0e..7cf7ff735 100644
--- a/pkg/sentry/socket/epsocket/stack.go
+++ b/pkg/sentry/socket/epsocket/stack.go
@@ -18,7 +18,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -143,3 +146,57 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
func (s *Stack) Statistics(stat interface{}, arg string) error {
return syserr.ErrEndpointOperation.ToError()
}
+
+// RouteTable implements inet.Stack.RouteTable.
+func (s *Stack) RouteTable() []inet.Route {
+ var routeTable []inet.Route
+
+ for _, rt := range s.Stack.GetRouteTable() {
+ var family uint8
+ switch len(rt.Destination.ID()) {
+ case header.IPv4AddressSize:
+ family = linux.AF_INET
+ case header.IPv6AddressSize:
+ family = linux.AF_INET6
+ default:
+ log.Warningf("Unknown network protocol in route %+v", rt)
+ continue
+ }
+
+ routeTable = append(routeTable, inet.Route{
+ Family: family,
+ DstLen: uint8(rt.Destination.Prefix()), // The CIDR prefix for the destination.
+
+ // Always return unspecified protocol since we have no notion of
+ // protocol for routes.
+ Protocol: linux.RTPROT_UNSPEC,
+ // Set statically to LINK scope for now.
+ //
+ // TODO(gvisor.dev/issue/595): Set scope for routes.
+ Scope: linux.RT_SCOPE_LINK,
+ Type: linux.RTN_UNICAST,
+
+ DstAddr: []byte(rt.Destination.ID()),
+ OutputInterface: int32(rt.NIC),
+ GatewayAddr: []byte(rt.Gateway),
+ })
+ }
+
+ return routeTable
+}
+
+// IPTables returns the stack's iptables.
+func (s *Stack) IPTables() (iptables.IPTables, error) {
+ return s.Stack.IPTables(), nil
+}
+
+// FillDefaultIPTables sets the stack's iptables to the default tables, which
+// allow and do not modify all traffic.
+func (s *Stack) FillDefaultIPTables() {
+ netfilter.FillDefaultIPTables(s.Stack)
+}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {
+ s.Stack.Resume()
+}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 7f69406b7..92beb1bcf 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -189,15 +189,16 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
- var peerAddr []byte
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ var peerAddr linux.SockAddr
+ var peerAddrBuf []byte
var peerAddrlen uint32
var peerAddrPtr *byte
var peerAddrlenPtr *uint32
if peerRequested {
- peerAddr = make([]byte, sizeofSockaddr)
- peerAddrlen = uint32(len(peerAddr))
- peerAddrPtr = &peerAddr[0]
+ peerAddrBuf = make([]byte, sizeofSockaddr)
+ peerAddrlen = uint32(len(peerAddrBuf))
+ peerAddrPtr = &peerAddrBuf[0]
peerAddrlenPtr = &peerAddrlen
}
@@ -222,7 +223,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
if peerRequested {
- peerAddr = peerAddr[:peerAddrlen]
+ peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen])
}
if syscallErr != nil {
return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
@@ -272,7 +273,7 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -353,7 +354,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
// Whitelist flags.
//
// FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
@@ -363,9 +364,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
}
- var senderAddr []byte
+ var senderAddr linux.SockAddr
+ var senderAddrBuf []byte
if senderRequested {
- senderAddr = make([]byte, sizeofSockaddr)
+ senderAddrBuf = make([]byte, sizeofSockaddr)
}
var msgFlags int
@@ -384,7 +386,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if dsts.NumBlocks() == 1 {
// Skip allocating []syscall.Iovec.
- return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddr)
+ return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf)
}
iovs := iovecsFromBlockSeq(dsts)
@@ -392,15 +394,15 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
Iov: &iovs[0],
Iovlen: uint64(len(iovs)),
}
- if len(senderAddr) != 0 {
- msg.Name = &senderAddr[0]
- msg.Namelen = uint32(len(senderAddr))
+ if len(senderAddrBuf) != 0 {
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(len(senderAddrBuf))
}
n, err := recvmsg(s.fd, &msg, sysflags)
if err != nil {
return 0, err
}
- senderAddr = senderAddr[:msg.Namelen]
+ senderAddrBuf = senderAddrBuf[:msg.Namelen]
msgFlags = int(msg.Flags)
return n, nil
})
@@ -431,7 +433,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// We don't allow control messages.
msgFlags &^= linux.MSG_CTRUNC
- return int(n), msgFlags, senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err)
+ if senderRequested {
+ senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
+ }
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err)
}
// SendMsg implements socket.Socket.SendMsg.
diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go
index 6c69ba9c7..e69ec38c2 100644
--- a/pkg/sentry/socket/hostinet/socket_unsafe.go
+++ b/pkg/sentry/socket/hostinet/socket_unsafe.go
@@ -18,10 +18,12 @@ import (
"syscall"
"unsafe"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
@@ -91,25 +93,25 @@ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) {
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr := make([]byte, sizeofSockaddr)
addrlen := uint32(len(addr))
_, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
if errno != 0 {
return nil, 0, syserr.FromError(errno)
}
- return addr[:addrlen], addrlen, nil
+ return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr := make([]byte, sizeofSockaddr)
addrlen := uint32(len(addr))
_, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
if errno != 0 {
return nil, 0, syserr.FromError(errno)
}
- return addr[:addrlen], addrlen, nil
+ return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil
}
func recvfrom(fd int, dst []byte, flags int, from *[]byte) (uint64, error) {
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index cc1f66fa1..3a4fdec47 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -46,6 +46,7 @@ type Stack struct {
// Stack is immutable.
interfaces map[int32]inet.Interface
interfaceAddrs map[int32][]inet.InterfaceAddr
+ routes []inet.Route
supportsIPv6 bool
tcpRecvBufSize inet.TCPBufferSize
tcpSendBufSize inet.TCPBufferSize
@@ -66,6 +67,10 @@ func (s *Stack) Configure() error {
return err
}
+ if err := addHostRoutes(s); err != nil {
+ return err
+ }
+
if _, err := os.Stat("/proc/net/if_inet6"); err == nil {
s.supportsIPv6 = true
}
@@ -161,6 +166,60 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
return nil
}
+// ExtractHostRoutes populates the given routes slice with the data from the
+// host route table.
+func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) {
+ var routes []inet.Route
+ for _, routeMsg := range routeMsgs {
+ if routeMsg.Header.Type != syscall.RTM_NEWROUTE {
+ continue
+ }
+
+ var ifRoute syscall.RtMsg
+ binary.Unmarshal(routeMsg.Data[:syscall.SizeofRtMsg], usermem.ByteOrder, &ifRoute)
+ inetRoute := inet.Route{
+ Family: ifRoute.Family,
+ DstLen: ifRoute.Dst_len,
+ SrcLen: ifRoute.Src_len,
+ TOS: ifRoute.Tos,
+ Table: ifRoute.Table,
+ Protocol: ifRoute.Protocol,
+ Scope: ifRoute.Scope,
+ Type: ifRoute.Type,
+ Flags: ifRoute.Flags,
+ }
+
+ // Not clearly documented: syscall.ParseNetlinkRouteAttr will check the
+ // syscall.NetlinkMessage.Header.Type and skip the struct rtmsg
+ // accordingly.
+ attrs, err := syscall.ParseNetlinkRouteAttr(&routeMsg)
+ if err != nil {
+ return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid rtattrs: %v", err)
+ }
+
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.RTA_DST:
+ inetRoute.DstAddr = attr.Value
+ case syscall.RTA_SRC:
+ inetRoute.SrcAddr = attr.Value
+ case syscall.RTA_GATEWAY:
+ inetRoute.GatewayAddr = attr.Value
+ case syscall.RTA_OIF:
+ expected := int(binary.Size(inetRoute.OutputInterface))
+ if len(attr.Value) != expected {
+ return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected)
+ }
+ binary.Unmarshal(attr.Value, usermem.ByteOrder, &inetRoute.OutputInterface)
+ }
+ }
+
+ routes = append(routes, inetRoute)
+ }
+
+ return routes, nil
+}
+
func addHostInterfaces(s *Stack) error {
links, err := doNetlinkRouteRequest(syscall.RTM_GETLINK)
if err != nil {
@@ -175,6 +234,20 @@ func addHostInterfaces(s *Stack) error {
return ExtractHostInterfaces(links, addrs, s.interfaces, s.interfaceAddrs)
}
+func addHostRoutes(s *Stack) error {
+ routes, err := doNetlinkRouteRequest(syscall.RTM_GETROUTE)
+ if err != nil {
+ return fmt.Errorf("RTM_GETROUTE failed: %v", err)
+ }
+
+ s.routes, err = ExtractHostRoutes(routes)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func doNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
data, err := syscall.NetlinkRIB(req, syscall.AF_UNSPEC)
if err != nil {
@@ -202,12 +275,20 @@ func readTCPBufferSizeFile(filename string) (inet.TCPBufferSize, error) {
// Interfaces implements inet.Stack.Interfaces.
func (s *Stack) Interfaces() map[int32]inet.Interface {
- return s.interfaces
+ interfaces := make(map[int32]inet.Interface)
+ for k, v := range s.interfaces {
+ interfaces[k] = v
+ }
+ return interfaces
}
// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
- return s.interfaceAddrs
+ addrs := make(map[int32][]inet.InterfaceAddr)
+ for k, v := range s.interfaceAddrs {
+ addrs[k] = append([]inet.InterfaceAddr(nil), v...)
+ }
+ return addrs
}
// SupportsIPv6 implements inet.Stack.SupportsIPv6.
@@ -249,3 +330,11 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
func (s *Stack) Statistics(stat interface{}, arg string) error {
return syserror.EOPNOTSUPP
}
+
+// RouteTable implements inet.Stack.RouteTable.
+func (s *Stack) RouteTable() []inet.Route {
+ return append([]inet.Route(nil), s.routes...)
+}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {}
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
new file mode 100644
index 000000000..354a0d6ee
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -0,0 +1,24 @@
+package(licenses = ["notice"])
+
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+go_library(
+ name = "netfilter",
+ srcs = [
+ "netfilter.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netfilter",
+ # This target depends on netstack and should only be used by epsocket,
+ # which is allowed to depend on netstack.
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/usermem",
+ "//pkg/syserr",
+ "//pkg/tcpip",
+ "//pkg/tcpip/iptables",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
new file mode 100644
index 000000000..9f87c32f1
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -0,0 +1,286 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netfilter helps the sentry interact with netstack's netfilter
+// capabilities.
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// errorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const errorTargetName = "ERROR"
+
+// metadata is opaque to netstack. It holds data that we need to translate
+// between Linux's and netstack's iptables representations.
+type metadata struct {
+ HookEntry [linux.NF_INET_NUMHOOKS]uint32
+ Underflow [linux.NF_INET_NUMHOOKS]uint32
+ NumEntries uint32
+ Size uint32
+}
+
+// GetInfo returns information about iptables.
+func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
+ // Read in the struct and table name.
+ var info linux.IPTGetinfo
+ if _, err := t.CopyIn(outPtr, &info); err != nil {
+ return linux.IPTGetinfo{}, syserr.FromError(err)
+ }
+
+ // Find the appropriate table.
+ table, err := findTable(ep, info.TableName())
+ if err != nil {
+ return linux.IPTGetinfo{}, err
+ }
+
+ // Get the hooks that apply to this table.
+ info.ValidHooks = table.ValidHooks()
+
+ // Grab the metadata struct, which is used to store information (e.g.
+ // the number of entries) that applies to the user's encoding of
+ // iptables, but not netstack's.
+ metadata := table.Metadata().(metadata)
+
+ // Set values from metadata.
+ info.HookEntry = metadata.HookEntry
+ info.Underflow = metadata.Underflow
+ info.NumEntries = metadata.NumEntries
+ info.Size = metadata.Size
+
+ return info, nil
+}
+
+// GetEntries returns netstack's iptables rules encoded for the iptables tool.
+func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
+ // Read in the struct and table name.
+ var userEntries linux.IPTGetEntries
+ if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+ return linux.KernelIPTGetEntries{}, syserr.FromError(err)
+ }
+
+ // Find the appropriate table.
+ table, err := findTable(ep, userEntries.TableName())
+ if err != nil {
+ return linux.KernelIPTGetEntries{}, err
+ }
+
+ // Convert netstack's iptables rules to something that the iptables
+ // tool can understand.
+ entries, _, err := convertNetstackToBinary(userEntries.TableName(), table)
+ if err != nil {
+ return linux.KernelIPTGetEntries{}, err
+ }
+ if binary.Size(entries) > uintptr(outLen) {
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
+ }
+
+ return entries, nil
+}
+
+func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) {
+ ipt, err := ep.IPTables()
+ if err != nil {
+ return iptables.Table{}, syserr.FromError(err)
+ }
+ table, ok := ipt.Tables[tableName]
+ if !ok {
+ return iptables.Table{}, syserr.ErrInvalidArgument
+ }
+ return table, nil
+}
+
+// FillDefaultIPTables sets stack's IPTables to the default tables and
+// populates them with metadata.
+func FillDefaultIPTables(stack *stack.Stack) {
+ ipt := iptables.DefaultTables()
+
+ // In order to fill in the metadata, we have to translate ipt from its
+ // netstack format to Linux's giant-binary-blob format.
+ for name, table := range ipt.Tables {
+ _, metadata, err := convertNetstackToBinary(name, table)
+ if err != nil {
+ panic(fmt.Errorf("Unable to set default IP tables: %v", err))
+ }
+ table.SetMetadata(metadata)
+ ipt.Tables[name] = table
+ }
+
+ stack.SetIPTables(ipt)
+}
+
+// convertNetstackToBinary converts the iptables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a bit, reading some offsets,
+// jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
+ // Return values.
+ var entries linux.KernelIPTGetEntries
+ var meta metadata
+
+ // The table name has to fit in the struct.
+ if linux.XT_TABLE_MAXNAMELEN < len(name) {
+ return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+ }
+ copy(entries.Name[:], name)
+
+ // Deal with the built in chains first (INPUT, OUTPUT, etc.). Each of
+ // these chains ends with an unconditional policy entry.
+ for hook := iptables.Prerouting; hook < iptables.NumHooks; hook++ {
+ chain, ok := table.BuiltinChains[hook]
+ if !ok {
+ // This table doesn't support this hook.
+ continue
+ }
+
+ // Sanity check.
+ if len(chain.Rules) < 1 {
+ return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+ }
+
+ for ruleIdx, rule := range chain.Rules {
+ // If this is the first rule of a builtin chain, set
+ // the metadata hook entry point.
+ if ruleIdx == 0 {
+ meta.HookEntry[hook] = entries.Size
+ }
+
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIPTEntry{
+ IPTEntry: linux.IPTEntry{
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
+
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
+ entry.TargetOffset += uint16(len(serialized))
+ }
+
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
+
+ // The underflow rule is the last rule in the chain,
+ // and is an unconditional rule (i.e. it matches any
+ // packet). This is enforced when saving iptables.
+ if ruleIdx == len(chain.Rules)-1 {
+ meta.Underflow[hook] = entries.Size
+ }
+
+ entries.Size += uint32(entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ meta.NumEntries++
+ }
+
+ }
+
+ // TODO(gvisor.dev/issue/170): Deal with the user chains here. Each of
+ // these starts with an error node holding the chain's name and ends
+ // with an unconditional return.
+
+ // Lastly, each table ends with an unconditional error target rule as
+ // its final entry.
+ errorEntry := linux.KernelIPTEntry{
+ IPTEntry: linux.IPTEntry{
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
+ var errorTarget linux.XTErrorTarget
+ errorTarget.Target.TargetSize = linux.SizeOfXTErrorTarget
+ copy(errorTarget.ErrorName[:], errorTargetName)
+ copy(errorTarget.Target.Name[:], errorTargetName)
+
+ // Serialize and add it to the list of entries.
+ errorTargetBuf := make([]byte, 0, linux.SizeOfXTErrorTarget)
+ serializedErrorTarget := binary.Marshal(errorTargetBuf, usermem.ByteOrder, errorTarget)
+ errorEntry.Elems = append(errorEntry.Elems, serializedErrorTarget...)
+ errorEntry.NextOffset += uint16(len(serializedErrorTarget))
+
+ entries.Size += uint32(errorEntry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, errorEntry)
+ meta.NumEntries++
+ meta.Size = entries.Size
+
+ return entries, meta, nil
+}
+
+func marshalMatcher(matcher iptables.Matcher) []byte {
+ switch matcher.(type) {
+ default:
+ // TODO(gvisor.dev/issue/170): We don't support any matchers yet, so
+ // any call to marshalMatcher will panic.
+ panic(fmt.Errorf("unknown matcher of type %T", matcher))
+ }
+}
+
+func marshalTarget(target iptables.Target) []byte {
+ switch target.(type) {
+ case iptables.UnconditionalAcceptTarget:
+ return marshalUnconditionalAcceptTarget()
+ default:
+ panic(fmt.Errorf("unknown target of type %T", target))
+ }
+}
+
+func marshalUnconditionalAcceptTarget() []byte {
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ Verdict: translateStandardVerdict(iptables.Accept),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+// translateStandardVerdict translates verdicts the same way as the iptables
+// tool.
+func translateStandardVerdict(verdict iptables.Verdict) int32 {
+ switch verdict {
+ case iptables.Accept:
+ return -linux.NF_ACCEPT - 1
+ case iptables.Drop:
+ return -linux.NF_DROP - 1
+ case iptables.Queue:
+ return -linux.NF_QUEUE - 1
+ case iptables.Return:
+ return linux.NF_RETURN
+ case iptables.Jump:
+ // TODO(gvisor.dev/issue/170): Support Jump.
+ panic("Jump isn't supported yet")
+ default:
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+ }
+}
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index fb1ff329c..cc70ac237 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -110,7 +110,7 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader
m.PutAttr(linux.IFLA_ADDRESS, mac)
m.PutAttr(linux.IFLA_BROADCAST, brd)
- // TODO(b/68878065): There are many more attributes.
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
}
return nil
@@ -151,13 +151,69 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader
m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))
- // TODO(b/68878065): There are many more attributes.
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
}
}
return nil
}
+// dumpRoutes handles RTM_GETROUTE + NLM_F_DUMP requests.
+func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+ // RTM_GETROUTE dump requests need not contain anything more than the
+ // netlink header and 1 byte protocol family common to all
+ // NETLINK_ROUTE requests.
+
+ // We always send back an NLMSG_DONE.
+ ms.Multi = true
+
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network routes.
+ return nil
+ }
+
+ for _, rt := range stack.RouteTable() {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWROUTE,
+ })
+
+ m.Put(linux.RouteMessage{
+ Family: rt.Family,
+ DstLen: rt.DstLen,
+ SrcLen: rt.SrcLen,
+ TOS: rt.TOS,
+
+ // Always return the main table since we don't have multiple
+ // routing tables.
+ Table: linux.RT_TABLE_MAIN,
+ Protocol: rt.Protocol,
+ Scope: rt.Scope,
+ Type: rt.Type,
+
+ Flags: rt.Flags,
+ })
+
+ m.PutAttr(254, []byte{123})
+ if rt.DstLen > 0 {
+ m.PutAttr(linux.RTA_DST, rt.DstAddr)
+ }
+ if rt.SrcLen > 0 {
+ m.PutAttr(linux.RTA_SRC, rt.SrcAddr)
+ }
+ if rt.OutputInterface != 0 {
+ m.PutAttr(linux.RTA_OIF, rt.OutputInterface)
+ }
+ if len(rt.GatewayAddr) > 0 {
+ m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr)
+ }
+
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
+ }
+
+ return nil
+}
+
// ProcessMessage implements netlink.Protocol.ProcessMessage.
func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
// All messages start with a 1 byte protocol family.
@@ -186,6 +242,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageH
return p.dumpLinks(ctx, hdr, data, ms)
case linux.RTM_GETADDR:
return p.dumpAddrs(ctx, hdr, data, ms)
+ case linux.RTM_GETROUTE:
+ return p.dumpRoutes(ctx, hdr, data, ms)
default:
return syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index f3d6c1e9b..d0aab293d 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -271,7 +271,7 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr
}
// Accept implements socket.Socket.Accept.
-func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Netlink sockets never support accept.
return 0, nil, 0, syserr.ErrNotSupported
}
@@ -289,7 +289,7 @@ func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -379,11 +379,11 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
s.mu.Lock()
defer s.mu.Unlock()
- sa := linux.SockAddrNetlink{
+ sa := &linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
PortID: uint32(s.portID),
}
@@ -391,8 +391,8 @@ func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
- sa := linux.SockAddrNetlink{
+func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ sa := &linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
// TODO(b/68878065): Support non-kernel peers. For now the peer
// must be the kernel.
@@ -402,8 +402,8 @@ func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
- from := linux.SockAddrNetlink{
+func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ from := &linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
PortID: 0,
}
@@ -511,6 +511,19 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
return nil
}
+func (s *Socket) dumpErrorMesage(ctx context.Context, hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) *syserr.Error {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.NLMSG_ERROR,
+ })
+
+ m.Put(linux.NetlinkErrorMessage{
+ Error: int32(-err.ToLinux().Number()),
+ Header: hdr,
+ })
+ return nil
+
+}
+
// processMessages handles each message in buf, passing it to the protocol
// handler for final handling.
func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
@@ -545,14 +558,20 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error
continue
}
+ ms := NewMessageSet(s.portID, hdr.Seq)
+ var err *syserr.Error
// TODO(b/68877377): ACKs not supported yet.
if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
- return syserr.ErrNotSupported
- }
+ err = syserr.ErrNotSupported
+ } else {
- ms := NewMessageSet(s.portID, hdr.Seq)
- if err := s.protocol.ProcessMessage(ctx, hdr, data, ms); err != nil {
- return err
+ err = s.protocol.ProcessMessage(ctx, hdr, data, ms)
+ }
+ if err != nil {
+ ms = NewMessageSet(s.portID, hdr.Seq)
+ if err := s.dumpErrorMesage(ctx, hdr, ms, err); err != nil {
+ return err
+ }
}
if err := s.sendResponse(ctx, ms); err != nil {
diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD
index a536f2e44..a3585e10d 100644
--- a/pkg/sentry/socket/rpcinet/notifier/BUILD
+++ b/pkg/sentry/socket/rpcinet/notifier/BUILD
@@ -6,10 +6,11 @@ go_library(
name = "notifier",
srcs = ["notifier.go"],
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier",
- visibility = ["//pkg/sentry:internal"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
"//pkg/sentry/socket/rpcinet/conn",
"//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go
index aa157dd51..7efe4301f 100644
--- a/pkg/sentry/socket/rpcinet/notifier/notifier.go
+++ b/pkg/sentry/socket/rpcinet/notifier/notifier.go
@@ -20,6 +20,7 @@ import (
"sync"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
"gvisor.dev/gvisor/pkg/waiter"
@@ -76,7 +77,7 @@ func (n *Notifier) waitFD(fd uint32, fi *fdInfo, mask waiter.EventMask) error {
}
e := pb.EpollEvent{
- Events: mask.ToLinux() | -syscall.EPOLLET,
+ Events: mask.ToLinux() | unix.EPOLLET,
Fd: fd,
}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index ccaaddbfc..ddb76d9d4 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -285,7 +285,7 @@ func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultP
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
payload, se := rpcAccept(t, s.fd, peerRequested)
// Check if we need to block.
@@ -328,6 +328,9 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
NonBlocking: flags&linux.SOCK_NONBLOCK != 0,
}
file := fs.NewFile(t, dirent, fileFlags, &socketOperations{
+ family: s.family,
+ stype: s.stype,
+ protocol: s.protocol,
wq: &wq,
fd: payload.Fd,
rpcConn: s.rpcConn,
@@ -344,7 +347,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
t.Kernel().RecordSocket(file)
if peerRequested {
- return fd, payload.Address.Address, payload.Address.Length, nil
+ return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil
}
return fd, nil, 0, nil
@@ -395,7 +398,7 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
// SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed
// within the sentry.
if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
@@ -469,7 +472,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
<-c
@@ -480,11 +483,11 @@ func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy
}
addr := res.(*pb.GetPeerNameResponse_Address).Address
- return addr.Address, addr.Length, nil
+ return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
<-c
@@ -495,7 +498,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy
}
addr := res.(*pb.GetSockNameResponse_Address).Address
- return addr.Address, addr.Length, nil
+ return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
}
func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) {
@@ -682,7 +685,7 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
Fd: s.fd,
Length: uint32(dst.NumBytes()),
@@ -703,7 +706,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
c := s.extractControlMessages(res)
- return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
}
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
return 0, 0, nil, 0, socket.ControlMessages{}, err
@@ -727,7 +730,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
c := s.extractControlMessages(res)
- return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
}
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
return 0, 0, nil, 0, socket.ControlMessages{}, err
diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go
index 49bd3a220..5dcb6b455 100644
--- a/pkg/sentry/socket/rpcinet/stack.go
+++ b/pkg/sentry/socket/rpcinet/stack.go
@@ -30,6 +30,7 @@ import (
type Stack struct {
interfaces map[int32]inet.Interface
interfaceAddrs map[int32][]inet.InterfaceAddr
+ routes []inet.Route
rpcConn *conn.RPCConnection
notifier *notifier.Notifier
}
@@ -69,6 +70,16 @@ func NewStack(fd int32) (*Stack, error) {
return nil, e
}
+ routes, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETROUTE)
+ if err != nil {
+ return nil, fmt.Errorf("RTM_GETROUTE failed: %v", err)
+ }
+
+ stack.routes, e = hostinet.ExtractHostRoutes(routes)
+ if e != nil {
+ return nil, e
+ }
+
return stack, nil
}
@@ -89,12 +100,20 @@ func (s *Stack) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) {
// Interfaces implements inet.Stack.Interfaces.
func (s *Stack) Interfaces() map[int32]inet.Interface {
- return s.interfaces
+ interfaces := make(map[int32]inet.Interface)
+ for k, v := range s.interfaces {
+ interfaces[k] = v
+ }
+ return interfaces
}
// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
- return s.interfaceAddrs
+ addrs := make(map[int32][]inet.InterfaceAddr)
+ for k, v := range s.interfaceAddrs {
+ addrs[k] = append([]inet.InterfaceAddr(nil), v...)
+ }
+ return addrs
}
// SupportsIPv6 implements inet.Stack.SupportsIPv6.
@@ -138,3 +157,11 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
func (s *Stack) Statistics(stat interface{}, arg string) error {
return syserr.ErrEndpointOperation.ToError()
}
+
+// RouteTable implements inet.Stack.RouteTable.
+func (s *Stack) RouteTable() []inet.Route {
+ return append([]inet.Route(nil), s.routes...)
+}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 0efa58a58..8c250c325 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -20,8 +20,10 @@ package socket
import (
"fmt"
"sync/atomic"
+ "syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -52,7 +54,7 @@ type Socket interface {
// Accept implements the accept4(2) linux syscall.
// Returns fd, real peer address length and error. Real peer address
// length is only set if len(peer) > 0.
- Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error)
+ Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error)
// Bind implements the bind(2) linux syscall.
Bind(t *kernel.Task, sockaddr []byte) *syserr.Error
@@ -64,7 +66,7 @@ type Socket interface {
Shutdown(t *kernel.Task, how int) *syserr.Error
// GetSockOpt implements the getsockopt(2) linux syscall.
- GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error)
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error)
// SetSockOpt implements the setsockopt(2) linux syscall.
SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
@@ -73,13 +75,13 @@ type Socket interface {
//
// addrLen is the address length to be returned to the application, not
// necessarily the actual length of the address.
- GetSockName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error)
+ GetSockName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error)
// GetPeerName implements the getpeername(2) linux syscall.
//
// addrLen is the address length to be returned to the application, not
// necessarily the actual length of the address.
- GetPeerName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error)
+ GetPeerName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error)
// RecvMsg implements the recvmsg(2) linux syscall.
//
@@ -92,7 +94,7 @@ type Socket interface {
// msgFlags. In that case, the caller should set MSG_CTRUNC appropriately.
//
// If err != nil, the recv was not successful.
- RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error)
+ RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error)
// SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take
// ownership of the ControlMessage on error.
@@ -340,3 +342,31 @@ func emitUnimplementedEvent(t *kernel.Task, name int) {
t.Kernel().EmitUnimplementedEvent(t)
}
}
+
+// UnmarshalSockAddr unmarshals memory representing a struct sockaddr to one of
+// the ABI socket address types.
+//
+// Precondition: data must be long enough to represent a socket address of the
+// given family.
+func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
+ switch family {
+ case syscall.AF_INET:
+ var addr linux.SockAddrInet
+ binary.Unmarshal(data[:syscall.SizeofSockaddrInet4], usermem.ByteOrder, &addr)
+ return &addr
+ case syscall.AF_INET6:
+ var addr linux.SockAddrInet6
+ binary.Unmarshal(data[:syscall.SizeofSockaddrInet6], usermem.ByteOrder, &addr)
+ return &addr
+ case syscall.AF_UNIX:
+ var addr linux.SockAddrUnix
+ binary.Unmarshal(data[:syscall.SizeofSockaddrUnix], usermem.ByteOrder, &addr)
+ return &addr
+ case syscall.AF_NETLINK:
+ var addr linux.SockAddrNetlink
+ binary.Unmarshal(data[:syscall.SizeofSockaddrNetlink], usermem.ByteOrder, &addr)
+ return &addr
+ default:
+ panic(fmt.Sprintf("Unsupported socket family %v", family))
+ }
+}
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
index 760c7beab..2ec1a662d 100644
--- a/pkg/sentry/socket/unix/io.go
+++ b/pkg/sentry/socket/unix/io.go
@@ -62,7 +62,7 @@ type EndpointReader struct {
Creds bool
// NumRights is the number of SCM_RIGHTS FDs requested.
- NumRights uintptr
+ NumRights int
// Peek indicates that the data should not be consumed from the
// endpoint.
@@ -70,7 +70,7 @@ type EndpointReader struct {
// MsgSize is the size of the message that was read from. For stream
// sockets, it is the amount read.
- MsgSize uintptr
+ MsgSize int64
// From, if not nil, will be set with the address read from.
From *tcpip.FullAddress
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 73d2df15d..4bd15808a 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -436,7 +436,7 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
-func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
// Stream sockets do not support specifying the endpoint. Seqpacket
// sockets ignore the passed endpoint.
if e.stype == linux.SOCK_STREAM && to != nil {
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index c7f7c5b16..0322dec0b 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -99,7 +99,7 @@ func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context) (Con
// SendMsg writes data and a control message to the specified endpoint.
// This method does not block if the data cannot be written.
-func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
if to == nil {
return e.baseEndpoint.SendMsg(ctx, data, c, nil)
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 7fb9cb1e0..2b0ad6395 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -121,13 +121,13 @@ type Endpoint interface {
// CMTruncated indicates that the numRights hint was used to receive fewer
// than the total available SCM_RIGHTS FDs. Additional truncation may be
// required by the caller.
- RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error)
+ RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, err *syserr.Error)
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
//
// SendMsg does not take ownership of any of its arguments on error.
- SendMsg(context.Context, [][]byte, ControlMessages, BoundEndpoint) (uintptr, *syserr.Error)
+ SendMsg(context.Context, [][]byte, ControlMessages, BoundEndpoint) (int64, *syserr.Error)
// Connect connects this endpoint directly to another.
//
@@ -291,7 +291,7 @@ type Receiver interface {
// See Endpoint.RecvMsg for documentation on shared arguments.
//
// notify indicates if RecvNotify should be called.
- Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
+ Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
// RecvNotify notifies the Receiver of a successful Recv. This must not be
// called while holding any endpoint locks.
@@ -331,7 +331,7 @@ type queueReceiver struct {
}
// Recv implements Receiver.Recv.
-func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
var m *message
var notify bool
var err *syserr.Error
@@ -344,13 +344,13 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek
return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
}
src := []byte(m.Data)
- var copied uintptr
+ var copied int64
for i := 0; i < len(data) && len(src) > 0; i++ {
n := copy(data[i], src)
- copied += uintptr(n)
+ copied += int64(n)
src = src[n:]
}
- return copied, uintptr(len(m.Data)), m.Control, false, m.Address, notify, nil
+ return copied, int64(len(m.Data)), m.Control, false, m.Address, notify, nil
}
// RecvNotify implements Receiver.RecvNotify.
@@ -401,11 +401,11 @@ type streamQueueReceiver struct {
addr tcpip.FullAddress
}
-func vecCopy(data [][]byte, buf []byte) (uintptr, [][]byte, []byte) {
- var copied uintptr
+func vecCopy(data [][]byte, buf []byte) (int64, [][]byte, []byte) {
+ var copied int64
for len(data) > 0 && len(buf) > 0 {
n := copy(data[0], buf)
- copied += uintptr(n)
+ copied += int64(n)
buf = buf[n:]
data[0] = data[0][n:]
if len(data[0]) == 0 {
@@ -443,7 +443,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
}
// Recv implements Receiver.Recv.
-func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
q.mu.Lock()
defer q.mu.Unlock()
@@ -464,7 +464,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
q.addr = m.Address
}
- var copied uintptr
+ var copied int64
if peek {
// Don't consume control message if we are peeking.
c := q.control.Clone()
@@ -531,7 +531,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
break
}
- var cpd uintptr
+ var cpd int64
cpd, data, q.buffer = vecCopy(data, q.buffer)
copied += cpd
@@ -569,7 +569,7 @@ type ConnectedEndpoint interface {
//
// syserr.ErrWouldBlock can be returned along with a partial write if
// the caller should block to send the rest of the data.
- Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n uintptr, notify bool, err *syserr.Error)
+ Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error)
// SendNotify notifies the ConnectedEndpoint of a successful Send. This
// must not be called while holding any endpoint locks.
@@ -637,7 +637,7 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
}
// Send implements ConnectedEndpoint.Send.
-func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) {
+func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
var l int64
for _, d := range data {
l += int64(len(d))
@@ -665,7 +665,7 @@ func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages,
}
l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate)
- return uintptr(l), notify, err
+ return int64(l), notify, err
}
// SendNotify implements ConnectedEndpoint.SendNotify.
@@ -781,7 +781,7 @@ func (e *baseEndpoint) Connected() bool {
}
// RecvMsg reads data and a control message from the endpoint.
-func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) {
+func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool, addr *tcpip.FullAddress) (int64, int64, ControlMessages, bool, *syserr.Error) {
e.Lock()
if e.receiver == nil {
@@ -807,7 +807,7 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, n
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
-func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
e.Lock()
if !e.Connected() {
e.Unlock()
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index eb262ecaf..0d0cb68df 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */)
+ addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
return "", err
}
@@ -137,7 +137,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) {
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -149,7 +149,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -166,7 +166,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
@@ -199,7 +199,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
// Accept implements the linux syscall accept(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
ep, err := s.ep.Accept()
if err != nil {
@@ -223,7 +223,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
ns.SetFlags(flags.Settable())
}
- var addr interface{}
+ var addr linux.SockAddr
var addrLen uint32
if peerRequested {
// Get address of the peer.
@@ -505,7 +505,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -535,7 +535,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
Ctx: t,
Endpoint: s.ep,
Creds: wantCreds,
- NumRights: uintptr(numRights),
+ NumRights: numRights,
Peek: peek,
}
if senderRequested {
@@ -543,7 +543,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
var total int64
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
- var from interface{}
+ var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
@@ -578,7 +578,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
for {
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
- var from interface{}
+ var from linux.SockAddr
var fromLen uint32
if r.From != nil {
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)