summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
authorKevin Krakauer <krakauer@google.com>2020-01-13 12:22:15 -0800
committerKevin Krakauer <krakauer@google.com>2020-01-13 12:22:15 -0800
commit31e49f4b19309259baeeb63e7b6ef41f8edd6d35 (patch)
treead6cae1ad4f173946eb68156653bea84d7254f81 /pkg/sentry/socket
parentd147e6d1b29d25607bcdcdb0beddb5122fea085e (diff)
parentb30cfb1df72e201c6caf576bbef8fcc968df2d41 (diff)
Merge branch 'master' into iptables-write-input-drop
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/control/control.go6
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go2
-rw-r--r--pkg/sentry/socket/netlink/BUILD2
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/port/port.go3
-rw-r--r--pkg/sentry/socket/netlink/socket.go31
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go56
-rw-r--r--pkg/sentry/socket/rpcinet/conn/BUILD1
-rw-r--r--pkg/sentry/socket/rpcinet/conn/conn.go2
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/BUILD1
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/notifier.go2
-rw-r--r--pkg/sentry/socket/unix/io.go13
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go3
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go3
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go26
-rw-r--r--pkg/sentry/socket/unix/unix.go23
18 files changed, 134 insertions, 43 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index af1a4e95f..4301b697c 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -471,6 +471,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_IP:
switch h.Type {
case linux.IP_TOS:
+ if length < linux.SizeOfControlMessageTOS {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
cmsgs.IP.HasTOS = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
i += AlignUp(length, width)
@@ -481,6 +484,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_IPV6:
switch h.Type {
case linux.IPV6_TCLASS:
+ if length < linux.SizeOfControlMessageTClass {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
cmsgs.IP.HasTClass = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
i += AlignUp(length, width)
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 014dfa625..37f726295 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -460,7 +460,7 @@ func hookFromLinux(hook int) iptables.Hook {
case linux.NF_INET_POST_ROUTING:
return iptables.Postrouting
}
- panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain"))
+ panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
}
// printReplace prints information about the struct ipt_replace in optVal. It
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 79589e3c8..103933144 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -22,12 +22,12 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 463544c1a..2d9f4ba9b 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -8,6 +8,7 @@ go_library(
srcs = ["port.go"],
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port",
visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go
index e9d3275b1..2cd3afc22 100644
--- a/pkg/sentry/socket/netlink/port/port.go
+++ b/pkg/sentry/socket/netlink/port/port.go
@@ -24,7 +24,8 @@ import (
"fmt"
"math"
"math/rand"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// maxPorts is a sanity limit on the maximum number of ports to allocate per
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 4a1b87a9a..cea56f4ed 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -17,7 +17,6 @@ package netlink
import (
"math"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
@@ -29,12 +28,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -500,29 +499,29 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
trunc := flags&linux.MSG_TRUNC != 0
r := unix.EndpointReader{
+ Ctx: t,
Endpoint: s.ep,
Peek: flags&linux.MSG_PEEK != 0,
}
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
// If MSG_TRUNC is set with a zero byte destination then we still need
// to read the message and discard it, or in the case where MSG_PEEK is
// set, leave it be. In both cases the full message length must be
- // returned. However, the memory manager for the destination will not read
- // the endpoint if the destination is zero length.
- //
- // In order for the endpoint to be read when the destination size is zero,
- // we must cause a read of the endpoint by using a separate fake zero
- // length block sequence and calling the EndpointReader directly.
+ // returned.
if trunc && dst.Addrs.NumBytes() == 0 {
- // Perform a read to a zero byte block sequence. We can ignore the
- // original destination since it was zero bytes. The length returned by
- // ReadToBlocks is ignored and we return the full message length to comply
- // with MSG_TRUNC.
- _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0))))
- return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
}
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
@@ -540,7 +539,7 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
defer s.EventUnregister(&e)
for {
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index e414d8055..f78784569 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/sentry/socket/netfilter",
"//pkg/sentry/unimpl",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 85d5b7424..f7b069eed 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -29,7 +29,6 @@ import (
"io"
"math"
"reflect"
- "sync"
"syscall"
"time"
@@ -49,6 +48,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -222,17 +222,25 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
+ // transport.Endpoint.SetSockOptBool.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
- SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
+ // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
- GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
}
// SocketOperations encapsulates all the state needed to represent a network stack
@@ -985,13 +993,23 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- if len(v) == 0 {
+ if v == 0 {
return []byte{}, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
}
- return append([]byte(v), 0), nil
+ s := t.NetworkContext()
+ if s == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ nic, ok := s.Interfaces()[int32(v)]
+ if !ok {
+ // The NICID no longer indicates a valid interface, probably because that
+ // interface was removed.
+ return nil, syserr.ErrUnknownDevice
+ }
+ return append([]byte(nic.Name), 0), nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1221,12 +1239,15 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.V6OnlyOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ var o uint32
+ if v {
+ o = 1
+ }
+ return int32(o), nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1455,7 +1476,20 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
if n == -1 {
n = len(optVal)
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+ name := string(optVal[:n])
+ if name == "" {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0)))
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return syserr.ErrNoDevice
+ }
+ for nicID, nic := range s.Interfaces() {
+ if nic.Name == name {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID)))
+ }
+ }
+ return syserr.ErrUnknownDevice
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
@@ -1649,7 +1683,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
case linux.IPV6_ADD_MEMBERSHIP,
linux.IPV6_DROP_MEMBERSHIP,
diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD
index 23eadcb1b..b2677c659 100644
--- a/pkg/sentry/socket/rpcinet/conn/BUILD
+++ b/pkg/sentry/socket/rpcinet/conn/BUILD
@@ -10,6 +10,7 @@ go_library(
deps = [
"//pkg/binary",
"//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/unet",
"@com_github_golang_protobuf//proto:go_default_library",
diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go
index 356adad99..02f39c767 100644
--- a/pkg/sentry/socket/rpcinet/conn/conn.go
+++ b/pkg/sentry/socket/rpcinet/conn/conn.go
@@ -17,12 +17,12 @@ package conn
import (
"fmt"
- "sync"
"sync/atomic"
"syscall"
"github.com/golang/protobuf/proto"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/unet"
diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD
index a3585e10d..a5954f22b 100644
--- a/pkg/sentry/socket/rpcinet/notifier/BUILD
+++ b/pkg/sentry/socket/rpcinet/notifier/BUILD
@@ -10,6 +10,7 @@ go_library(
deps = [
"//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
"//pkg/sentry/socket/rpcinet/conn",
+ "//pkg/sync",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go
index 7efe4301f..82b75d6dd 100644
--- a/pkg/sentry/socket/rpcinet/notifier/notifier.go
+++ b/pkg/sentry/socket/rpcinet/notifier/notifier.go
@@ -17,12 +17,12 @@ package notifier
import (
"fmt"
- "sync"
"syscall"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
index 2ec1a662d..2447f24ef 100644
--- a/pkg/sentry/socket/unix/io.go
+++ b/pkg/sentry/socket/unix/io.go
@@ -83,6 +83,19 @@ type EndpointReader struct {
ControlTrunc bool
}
+// Truncate calls RecvMsg on the endpoint without writing to a destination.
+func (r *EndpointReader) Truncate() error {
+ // Ignore bytes read since it will always be zero.
+ _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return err.ToError()
+ }
+ return nil
+}
+
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 788ad70d2..d7ba95dff 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/ilist",
"//pkg/refs",
"//pkg/sentry/context",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index dea11e253..9e6fbc111 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -15,10 +15,9 @@
package transport
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/waiter"
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index e27b1c714..5dcd3d95e 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -15,9 +15,8 @@
package transport
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 529a7a7a9..fcc0da332 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -16,11 +16,11 @@
package transport
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -175,17 +175,25 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptBool sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
- SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
+ // GetSockOptBool gets a socket option for simple cases when a return
+ // value has the int type.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
- GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
// State returns the current state of the socket, as represented by Linux in
// procfs.
@@ -851,11 +859,19 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return nil
+}
+
+func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
+func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 885758054..91effe89a 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -544,8 +544,27 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if senderRequested {
r.From = &tcpip.FullAddress{}
}
+
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
+
+ }
+
var total int64
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
+ if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait {
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
@@ -580,7 +599,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)
for {
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
var from linux.SockAddr
var fromLen uint32
if r.From != nil {