summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/netstack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/netstack')
-rw-r--r--pkg/sentry/socket/netstack/BUILD3
-rw-r--r--pkg/sentry/socket/netstack/netstack.go200
-rw-r--r--pkg/sentry/socket/netstack/netstack_state.go31
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go7
-rw-r--r--pkg/sentry/socket/netstack/stack.go2
5 files changed, 155 insertions, 88 deletions
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index e828982eb..075f61cda 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"device.go",
"netstack.go",
+ "netstack_state.go",
"netstack_vfs2.go",
"provider.go",
"provider_vfs2.go",
@@ -42,13 +43,13 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/header",
"//pkg/tcpip/link/tun",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9b844b0c0..030c6c8e4 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -56,12 +56,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -275,6 +274,7 @@ var Metrics = tcpip.Stats{
ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."),
SegmentsAckedWithDSACK: mustCreateMetric("/netstack/tcp/segments_acked_with_dsack", "Number of segments for which DSACK was received."),
+ SpuriousRecovery: mustCreateMetric("/netstack/tcp/spurious_recovery", "Number of times the connection entered loss recovery spuriously."),
},
UDP: tcpip.UDPStats{
PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),
@@ -379,9 +379,9 @@ type socketOpsCommon struct {
// timestampValid indicates whether timestamp for SIOCGSTAMP has been
// set. It is protected by readMu.
timestampValid bool
- // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only
+ // timestamp holds the timestamp to use with SIOCTSTAMP. It is only
// valid when timestampValid is true. It is protected by readMu.
- timestampNS int64
+ timestamp time.Time `state:".(int64)"`
// TODO(b/153685824): Move this to SocketOptions.
// sockOptInq corresponds to TCP_INQ.
@@ -411,13 +411,25 @@ var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes()
var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes()
var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes()
-// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
-// netstack representation taking any addresses into account.
-func bytesToIPAddress(addr []byte) tcpip.Address {
- if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
- return ""
+// minSockAddrLen returns the minimum length in bytes of a socket address for
+// the socket's family.
+func (s *socketOpsCommon) minSockAddrLen() int {
+ const addressFamilySize = 2
+
+ switch s.family {
+ case linux.AF_UNIX:
+ return addressFamilySize
+ case linux.AF_INET:
+ return sockAddrInetSize
+ case linux.AF_INET6:
+ return sockAddrInet6Size
+ case linux.AF_PACKET:
+ return sockAddrLinkSize
+ case linux.AF_UNSPEC:
+ return addressFamilySize
+ default:
+ panic(fmt.Sprintf("s.family unrecognized = %d", s.family))
}
- return tcpip.Address(addr)
}
func (s *socketOpsCommon) isPacketBased() bool {
@@ -448,7 +460,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
t := kernel.TaskFromContext(ctx)
start := t.Kernel().MonotonicClock().Now()
deadline := start.Add(v.Timeout)
- t.BlockWithDeadline(ch, true, deadline)
+ _ = t.BlockWithDeadline(ch, true, deadline)
}
}
@@ -459,7 +471,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
if err == syserr.ErrWouldBlock {
- return int64(n), syserror.ErrWouldBlock
+ return int64(n), linuxerr.ErrWouldBlock
}
if err != nil {
return 0, err.ToError()
@@ -468,7 +480,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
// WriteTo implements fs.FileOperations.WriteTo.
-func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+func (s *SocketOperations) WriteTo(_ context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
s.readMu.Lock()
defer s.readMu.Unlock()
@@ -492,14 +504,14 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
r := src.Reader(ctx)
n, err := s.Endpoint.Write(r, tcpip.WriteOptions{})
if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
if err != nil {
return 0, syserr.TranslateNetstackError(err).ToError()
}
if n < src.NumBytes() {
- return n, syserror.ErrWouldBlock
+ return n, linuxerr.ErrWouldBlock
}
return n, nil
@@ -523,7 +535,7 @@ func (l *limitedPayloader) Len() int {
}
// ReadFrom implements fs.FileOperations.ReadFrom.
-func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+func (s *SocketOperations) ReadFrom(_ context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
f := limitedPayloader{
inner: io.LimitedReader{
R: r,
@@ -546,16 +558,21 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.Endpoint.Readiness(mask)
}
-func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
+// checkFamily returns true iff the specified address family may be used with
+// the socket.
+//
+// If exact is true, then the specified address family must be an exact match
+// with the socket's family.
+func (s *socketOpsCommon) checkFamily(family uint16, exact bool) bool {
if family == uint16(s.family) {
- return nil
+ return true
}
if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
if !s.Endpoint.SocketOptions().GetV6Only() {
- return nil
+ return true
}
}
- return syserr.ErrInvalidArgument
+ return false
}
// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the
@@ -588,8 +605,8 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
return syserr.TranslateNetstackError(err)
}
- if err := s.checkFamily(family, false /* exact */); err != nil {
- return err
+ if !s.checkFamily(family, false /* exact */) {
+ return syserr.ErrInvalidArgument
}
addr = s.mapFamily(addr, family)
@@ -629,7 +646,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) < 2 {
return syserr.ErrInvalidArgument
}
@@ -647,23 +664,24 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
a.UnmarshalBytes(sockaddr[:sockAddrLinkSize])
- if a.Protocol != uint16(s.protocol) {
- return syserr.ErrInvalidArgument
- }
-
addr = tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ Port: socket.Ntohs(a.Protocol),
}
} else {
+ if s.minSockAddrLen() > len(sockaddr) {
+ return syserr.ErrInvalidArgument
+ }
+
var err *syserr.Error
addr, family, err = socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
- if err = s.checkFamily(family, true /* exact */); err != nil {
- return err
+ if !s.checkFamily(family, true /* exact */) {
+ return syserr.ErrAddressFamilyNotSupported
}
addr = s.mapFamily(addr, family)
@@ -688,7 +706,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
// Listen implements the linux syscall listen(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error {
return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog))
}
@@ -779,7 +797,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) {
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error {
f, err := ConvertShutdown(how)
if err != nil {
return err
@@ -860,7 +878,7 @@ func boolToInt32(v bool) int32 {
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, _ linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -1345,6 +1363,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
return &v, nil
+ case linux.IPV6_RECVPKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv6ReceivePacketInfo()))
+ return &v, nil
+
case linux.IP6T_ORIGINAL_DST:
if outLen < sockAddrInet6Size {
return nil, syserr.ErrInvalidArgument
@@ -1368,11 +1394,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true)
+ info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, true)
if err != nil {
return nil, err
}
@@ -1388,11 +1414,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen)
+ entries, err := netfilter.GetEntries6(t, stk.(*Stack).Stack, outPtr, outLen)
if err != nil {
return nil, err
}
@@ -1408,8 +1434,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber)
@@ -1425,7 +1451,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, _ int) (marshal.Marshallable, *syserr.Error) {
if _, ok := ep.(tcpip.Endpoint); !ok {
log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
return nil, syserr.ErrUnknownProtocolOption
@@ -1565,11 +1591,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false)
+ info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, false)
if err != nil {
return nil, err
}
@@ -1585,11 +1611,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- entries, err := netfilter.GetEntries4(t, stack.(*Stack).Stack, outPtr, outLen)
+ entries, err := netfilter.GetEntries4(t, stk.(*Stack).Stack, outPtr, outLen)
if err != nil {
return nil, err
}
@@ -1605,8 +1631,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber)
@@ -2046,7 +2072,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial {
return syserr.ErrInvalidEndpointState
- } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial {
+ } else if isUDPSocket(skType, skProto) && transport.DatagramEndpointState(ep.State()) != transport.DatagramEndpointStateInitial {
return syserr.ErrInvalidEndpointState
}
@@ -2101,6 +2127,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
return nil
+ case linux.IPV6_RECVPKTINFO:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(hostarch.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetIPv6ReceivePacketInfo(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2143,12 +2178,12 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true)
+ return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, true)
case linux.IP6T_SO_SET_ADD_COUNTERS:
log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported")
@@ -2386,12 +2421,12 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false)
+ return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, false)
case linux.IPT_SO_SET_ADD_COUNTERS:
log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported")
@@ -2490,7 +2525,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
linux.IPV6_RECVPATHMTU,
- linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
linux.IPV6_RTHDR,
linux.IPV6_RTHDRDSTOPTS,
@@ -2559,7 +2593,7 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -2571,7 +2605,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -2716,6 +2750,8 @@ func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.Contr
TClass: readCM.TClass,
HasIPPacketInfo: readCM.HasIPPacketInfo,
PacketInfo: readCM.PacketInfo,
+ HasIPv6PacketInfo: readCM.HasIPv6PacketInfo,
+ IPv6PacketInfo: readCM.IPv6PacketInfo,
OriginalDstAddress: readCM.OriginalDstAddress,
SockErr: readCM.SockErr,
},
@@ -2730,7 +2766,7 @@ func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) {
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
if !s.sockOptTimestamp {
s.timestampValid = true
- s.timestampNS = cm.Timestamp
+ s.timestamp = cm.Timestamp
}
}
@@ -2789,7 +2825,7 @@ func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int,
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, _ uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
if flags&linux.MSG_ERRQUEUE != 0 {
return s.recvErr(t, dst)
}
@@ -2873,8 +2909,8 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
if err != nil {
return 0, err
}
- if err := s.checkFamily(family, false /* exact */); err != nil {
- return 0, err
+ if !s.checkFamily(family, false /* exact */) {
+ return 0, syserr.ErrInvalidArgument
}
addrBuf = s.mapFamily(addrBuf, family)
@@ -2951,10 +2987,10 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
s.readMu.Lock()
defer s.readMu.Unlock()
if !s.timestampValid {
- return 0, syserror.ENOENT
+ return 0, linuxerr.ENOENT
}
- tv := linux.NsecToTimeval(s.timestampNS)
+ tv := linux.NsecToTimeval(s.timestamp.UnixNano())
_, err := tv.CopyOut(t, args[2].Pointer())
return 0, err
@@ -3061,7 +3097,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
}
// interfaceIoctl implements interface requests.
-func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
+func interfaceIoctl(ctx context.Context, _ usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
var (
iface inet.Interface
index int32
@@ -3069,8 +3105,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
)
// Find the relevant device.
- stack := inet.StackFromContext(ctx)
- if stack == nil {
+ stk := inet.StackFromContext(ctx)
+ if stk == nil {
return syserr.ErrNoDevice
}
@@ -3080,7 +3116,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// Gets the name of the interface given the interface index
// stored in ifr_ifindex.
index = int32(hostarch.ByteOrder.Uint32(ifr.Data[:4]))
- if iface, ok := stack.Interfaces()[index]; ok {
+ if iface, ok := stk.Interfaces()[index]; ok {
ifr.SetName(iface.Name)
return nil
}
@@ -3088,7 +3124,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// Find the relevant device.
- for index, iface = range stack.Interfaces() {
+ for index, iface = range stk.Interfaces() {
if iface.Name == ifr.Name() {
found = true
break
@@ -3121,7 +3157,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
case linux.SIOCGIFFLAGS:
- f, err := interfaceStatusFlags(stack, iface.Name)
+ f, err := interfaceStatusFlags(stk, iface.Name)
if err != nil {
return err
}
@@ -3131,7 +3167,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case linux.SIOCGIFADDR:
// Copy the IPv4 address out.
- for _, addr := range stack.InterfaceAddrs()[index] {
+ for _, addr := range stk.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
if addr.Family != linux.AF_INET {
continue
@@ -3167,7 +3203,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case linux.SIOCGIFNETMASK:
// Gets the network mask of a device.
- for _, addr := range stack.InterfaceAddrs()[index] {
+ for _, addr := range stk.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
if addr.Family != linux.AF_INET {
continue
@@ -3199,24 +3235,24 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error {
+func ifconfIoctl(ctx context.Context, t *kernel.Task, _ usermem.IO, ifc *linux.IFConf) error {
// If Ptr is NULL, return the necessary buffer size via Len.
// Otherwise, write up to Len bytes starting at Ptr containing ifreq
// structs.
- stack := inet.StackFromContext(ctx)
- if stack == nil {
+ stk := inet.StackFromContext(ctx)
+ if stk == nil {
return syserr.ErrNoDevice.ToError()
}
if ifc.Ptr == 0 {
- ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq)
+ ifc.Len = int32(len(stk.Interfaces())) * int32(linux.SizeOfIFReq)
return nil
}
max := ifc.Len
ifc.Len = 0
- for key, ifaceAddrs := range stack.InterfaceAddrs() {
- iface := stack.Interfaces()[key]
+ for key, ifaceAddrs := range stk.InterfaceAddrs() {
+ iface := stk.Interfaces()[key]
for _, ifaceAddr := range ifaceAddrs {
// Don't write past the end of the buffer.
if ifc.Len+int32(linux.SizeOfIFReq) > max {
@@ -3332,10 +3368,10 @@ func (s *socketOpsCommon) State() uint32 {
}
case isUDPSocket(s.skType, s.protocol):
// UDP socket.
- switch udp.EndpointState(s.Endpoint.State()) {
- case udp.StateInitial, udp.StateBound, udp.StateClosed:
+ switch transport.DatagramEndpointState(s.Endpoint.State()) {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateBound, transport.DatagramEndpointStateClosed:
return linux.TCP_CLOSE
- case udp.StateConnected:
+ case transport.DatagramEndpointStateConnected:
return linux.TCP_ESTABLISHED
default:
return 0
diff --git a/pkg/sentry/socket/netstack/netstack_state.go b/pkg/sentry/socket/netstack/netstack_state.go
new file mode 100644
index 000000000..591e00d42
--- /dev/null
+++ b/pkg/sentry/socket/netstack/netstack_state.go
@@ -0,0 +1,31 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "time"
+)
+
+func (s *socketOpsCommon) saveTimestamp() int64 {
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ return s.timestamp.UnixNano()
+}
+
+func (s *socketOpsCommon) loadTimestamp(nsec int64) {
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.timestamp = time.Unix(0, nsec)
+}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index edc160b1b..3cdf29b80 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -113,7 +112,7 @@ func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.
}
n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
if err == syserr.ErrWouldBlock {
- return int64(n), syserror.ErrWouldBlock
+ return int64(n), linuxerr.ErrWouldBlock
}
if err != nil {
return 0, err.ToError()
@@ -132,14 +131,14 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
r := src.Reader(ctx)
n, err := s.Endpoint.Write(r, tcpip.WriteOptions{})
if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
if err != nil {
return 0, syserr.TranslateNetstackError(err).ToError()
}
if n < src.NumBytes() {
- return n, syserror.ErrWouldBlock
+ return n, linuxerr.ErrWouldBlock
}
return n, nil
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 208ab9909..ea199f223 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
// Attach address to interface.
nicID := tcpip.NICID(idx)
- if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ if err := s.Stack.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
return syserr.TranslateNetstackError(err).ToError()
}