diff options
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r-- | pkg/sentry/socket/BUILD | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/control/control.go | 32 | ||||
-rw-r--r-- | pkg/sentry/socket/control/control_test.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 39 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/targets.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/BUILD | 3 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 200 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack_state.go | 31 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack_vfs2.go | 7 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/stack.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/socket.go | 40 | ||||
-rw-r--r-- | pkg/sentry/socket/socket_state.go | 27 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectioned.go | 9 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectionless.go | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/queue.go | 10 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 11 |
20 files changed, 290 insertions, 143 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 7ee89a735..00f925166 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -4,7 +4,10 @@ package(licenses = ["notice"]) go_library( name = "socket", - srcs = ["socket.go"], + srcs = [ + "socket.go", + "socket_state.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 00a5e729a..6077b2150 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -29,10 +29,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "time" ) -const maxInt = int(^uint(0) >> 1) - // SCMCredentials represents a SCM_CREDENTIALS socket control message. type SCMCredentials interface { transport.CredentialsControlMessage @@ -78,7 +77,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { } // Files implements SCMRights.Files. -func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) { +func (fs *RightsFiles) Files(_ context.Context, max int) (RightsFiles, bool) { n := max var trunc bool if l := len(*fs); n > l { @@ -124,7 +123,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32 break } - fds = append(fds, int32(fd)) + fds = append(fds, fd) } return fds, trunc } @@ -300,8 +299,8 @@ func alignSlice(buf []byte, align uint) []byte { } // PackTimestamp packs a SO_TIMESTAMP socket control message. -func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte { - timestampP := linux.NsecToTimeval(timestamp) +func PackTimestamp(t *kernel.Task, timestamp time.Time, buf []byte) []byte { + timestampP := linux.NsecToTimeval(timestamp.UnixNano()) return putCmsgStruct( buf, linux.SOL_SOCKET, @@ -355,6 +354,17 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketIn ) } +// PackIPv6PacketInfo packs an IPV6_PKTINFO socket control message. +func PackIPv6PacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPv6PacketInfo, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IPV6, + linux.IPV6_PKTINFO, + t.Arch().Width(), + packetInfo, + ) +} + // PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte { var level uint32 @@ -412,6 +422,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf) } + if cmsgs.IP.HasIPv6PacketInfo { + buf = PackIPv6PacketInfo(t, &cmsgs.IP.IPv6PacketInfo, buf) + } + if cmsgs.IP.OriginalDstAddress != nil { buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) } @@ -453,6 +467,10 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo) } + if cmsgs.IP.HasIPv6PacketInfo { + space += cmsgSpace(t, linux.SizeOfControlMessageIPv6PacketInfo) + } + if cmsgs.IP.OriginalDstAddress != nil { space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) } @@ -526,7 +544,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var ts linux.Timeval ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) - cmsgs.IP.Timestamp = ts.ToNsecCapped() + cmsgs.IP.Timestamp = ts.ToTime() cmsgs.IP.HasTimestamp = true i += bits.AlignUp(length, width) diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go index 7e28a0cef..1b04e1bbc 100644 --- a/pkg/sentry/socket/control/control_test.go +++ b/pkg/sentry/socket/control/control_test.go @@ -50,7 +50,7 @@ func TestParse(t *testing.T) { want := socket.ControlMessages{ IP: socket.IPControlMessages{ HasTimestamp: true, - Timestamp: ts.ToNsecCapped(), + Timestamp: ts.ToTime(), }, } if diff := cmp.Diff(want, cmsg); diff != "" { diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 3950caa0f..4ea89f9d0 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -38,7 +38,6 @@ go_library( "//pkg/sentry/socket/control", "//pkg/sentry/vfs", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/stack", "//pkg/usermem", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 38cb2c99c..6e2318f75 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -35,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -112,7 +111,7 @@ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } return readv(s.fd, safemem.IovecsFromBlockSeq(dsts)) })) - return int64(n), err + return n, err } // Write implements fs.FileOperations.Write. @@ -135,7 +134,7 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } return writev(s.fd, safemem.IovecsFromBlockSeq(srcs)) })) - return int64(n), err + return n, err } // Socket implements socket.Provider.Socket. @@ -181,7 +180,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } @@ -208,7 +207,7 @@ type socketOpsCommon struct { // Release implements fs.FileOperations.Release. func (s *socketOpsCommon) Release(context.Context) { fdnotifier.RemoveFD(int32(s.fd)) - unix.Close(s.fd) + _ = unix.Close(s.fd) } // Readiness implements waiter.Waitable.Readiness. @@ -219,13 +218,13 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { // EventRegister implements waiter.Waitable.EventRegister. func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.queue.EventRegister(e, mask) - fdnotifier.UpdateFD(int32(s.fd)) + _ = fdnotifier.UpdateFD(int32(s.fd)) } // EventUnregister implements waiter.Waitable.EventUnregister. func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.queue.EventUnregister(e) - fdnotifier.UpdateFD(int32(s.fd)) + _ = fdnotifier.UpdateFD(int32(s.fd)) } // Connect implements socket.Socket.Connect. @@ -288,7 +287,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC) if blocking { var ch chan struct{} - for syscallErr == syserror.ErrWouldBlock { + for syscallErr == linuxerr.ErrWouldBlock { if ch != nil { if syscallErr = t.Block(ch); syscallErr != nil { break @@ -317,7 +316,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, if kernel.VFS2Enabled { f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&unix.SOCK_NONBLOCK)) if err != nil { - unix.Close(fd) + _ = unix.Close(fd) return 0, nil, 0, err } defer f.DecRef(t) @@ -329,7 +328,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, } else { f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&unix.SOCK_NONBLOCK != 0) if err != nil { - unix.Close(fd) + _ = unix.Close(fd) return 0, nil, 0, err } defer f.DecRef(t) @@ -344,7 +343,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, } // Bind implements socket.Socket.Bind. -func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) > sizeofSockaddr { sockaddr = sockaddr[:sizeofSockaddr] } @@ -357,12 +356,12 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Listen implements socket.Socket.Listen. -func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error { return syserr.FromError(unix.Listen(s.fd, backlog)) } // Shutdown implements socket.Socket.Shutdown. -func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error { switch how { case unix.SHUT_RD, unix.SHUT_WR, unix.SHUT_RDWR: return syserr.FromError(unix.Shutdown(s.fd, how)) @@ -372,7 +371,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, _ hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -402,7 +401,7 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr case linux.TCP_NODELAY: optlen = sizeofInt32 case linux.TCP_INFO: - optlen = int(linux.SizeOfTCPInfo) + optlen = linux.SizeOfTCPInfo } } @@ -535,7 +534,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags n, err := copyToDst() // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. if flags&(unix.MSG_DONTWAIT|unix.MSG_ERRQUEUE) == 0 { - for err == syserror.ErrWouldBlock { + for err == linuxerr.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which // case it can't have returned any data. if n != 0 { @@ -580,7 +579,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s controlMessages.IP.HasTimestamp = true ts := linux.Timeval{} ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) - controlMessages.IP.Timestamp = ts.ToNsecCapped() + controlMessages.IP.Timestamp = ts.ToTime() } case linux.SOL_IP: @@ -707,7 +706,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b var ch chan struct{} n, err := src.CopyInTo(t, sendmsgFromBlocks) if flags&unix.MSG_DONTWAIT == 0 { - for err == syserror.ErrWouldBlock { + for err == linuxerr.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which // case it can't have returned any data. if n != 0 { @@ -716,7 +715,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b if ch != nil { if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } break } @@ -735,7 +734,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b func translateIOSyscallError(err error) error { if err == unix.EAGAIN || err == unix.EWOULDBLOCK { - return syserror.ErrWouldBlock + return linuxerr.ErrWouldBlock } return err } diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index ea56f39c1..b9c15daab 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -647,7 +647,7 @@ func (jt *JumpTarget) id() targetID { } // Action implements stack.Target.Action. -func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index ed85404da..9710a15ee 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -36,7 +36,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 5c3ae26f8..ed5fa9c38 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -39,7 +39,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -530,7 +529,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } - if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + if n, err := doRead(); err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { var mflags int if n < int64(r.MsgSize) { mflags |= linux.MSG_TRUNC @@ -548,7 +547,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - if n, err := doRead(); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != linuxerr.ErrWouldBlock { var mflags int if n < int64(r.MsgSize) { mflags |= linux.MSG_TRUNC diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index e828982eb..075f61cda 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "device.go", "netstack.go", + "netstack_state.go", "netstack_vfs2.go", "provider.go", "provider_vfs2.go", @@ -42,13 +43,13 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/header", "//pkg/tcpip/link/tun", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/usermem", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9b844b0c0..030c6c8e4 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -56,12 +56,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -275,6 +274,7 @@ var Metrics = tcpip.Stats{ ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."), FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."), SegmentsAckedWithDSACK: mustCreateMetric("/netstack/tcp/segments_acked_with_dsack", "Number of segments for which DSACK was received."), + SpuriousRecovery: mustCreateMetric("/netstack/tcp/spurious_recovery", "Number of times the connection entered loss recovery spuriously."), }, UDP: tcpip.UDPStats{ PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."), @@ -379,9 +379,9 @@ type socketOpsCommon struct { // timestampValid indicates whether timestamp for SIOCGSTAMP has been // set. It is protected by readMu. timestampValid bool - // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only + // timestamp holds the timestamp to use with SIOCTSTAMP. It is only // valid when timestampValid is true. It is protected by readMu. - timestampNS int64 + timestamp time.Time `state:".(int64)"` // TODO(b/153685824): Move this to SocketOptions. // sockOptInq corresponds to TCP_INQ. @@ -411,13 +411,25 @@ var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes() var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes() var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes() -// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the -// netstack representation taking any addresses into account. -func bytesToIPAddress(addr []byte) tcpip.Address { - if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) { - return "" +// minSockAddrLen returns the minimum length in bytes of a socket address for +// the socket's family. +func (s *socketOpsCommon) minSockAddrLen() int { + const addressFamilySize = 2 + + switch s.family { + case linux.AF_UNIX: + return addressFamilySize + case linux.AF_INET: + return sockAddrInetSize + case linux.AF_INET6: + return sockAddrInet6Size + case linux.AF_PACKET: + return sockAddrLinkSize + case linux.AF_UNSPEC: + return addressFamilySize + default: + panic(fmt.Sprintf("s.family unrecognized = %d", s.family)) } - return tcpip.Address(addr) } func (s *socketOpsCommon) isPacketBased() bool { @@ -448,7 +460,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) { t := kernel.TaskFromContext(ctx) start := t.Kernel().MonotonicClock().Now() deadline := start.Add(v.Timeout) - t.BlockWithDeadline(ch, true, deadline) + _ = t.BlockWithDeadline(ch, true, deadline) } } @@ -459,7 +471,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false) if err == syserr.ErrWouldBlock { - return int64(n), syserror.ErrWouldBlock + return int64(n), linuxerr.ErrWouldBlock } if err != nil { return 0, err.ToError() @@ -468,7 +480,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } // WriteTo implements fs.FileOperations.WriteTo. -func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { +func (s *SocketOperations) WriteTo(_ context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { s.readMu.Lock() defer s.readMu.Unlock() @@ -492,14 +504,14 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO r := src.Reader(ctx) n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if _, ok := err.(*tcpip.ErrWouldBlock); ok { - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } if err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } if n < src.NumBytes() { - return n, syserror.ErrWouldBlock + return n, linuxerr.ErrWouldBlock } return n, nil @@ -523,7 +535,7 @@ func (l *limitedPayloader) Len() int { } // ReadFrom implements fs.FileOperations.ReadFrom. -func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { +func (s *SocketOperations) ReadFrom(_ context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { f := limitedPayloader{ inner: io.LimitedReader{ R: r, @@ -546,16 +558,21 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return s.Endpoint.Readiness(mask) } -func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { +// checkFamily returns true iff the specified address family may be used with +// the socket. +// +// If exact is true, then the specified address family must be an exact match +// with the socket's family. +func (s *socketOpsCommon) checkFamily(family uint16, exact bool) bool { if family == uint16(s.family) { - return nil + return true } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { if !s.Endpoint.SocketOptions().GetV6Only() { - return nil + return true } } - return syserr.ErrInvalidArgument + return false } // mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the @@ -588,8 +605,8 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool return syserr.TranslateNetstackError(err) } - if err := s.checkFamily(family, false /* exact */); err != nil { - return err + if !s.checkFamily(family, false /* exact */) { + return syserr.ErrInvalidArgument } addr = s.mapFamily(addr, family) @@ -629,7 +646,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < 2 { return syserr.ErrInvalidArgument } @@ -647,23 +664,24 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) - if a.Protocol != uint16(s.protocol) { - return syserr.ErrInvalidArgument - } - addr = tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + Port: socket.Ntohs(a.Protocol), } } else { + if s.minSockAddrLen() > len(sockaddr) { + return syserr.ErrInvalidArgument + } + var err *syserr.Error addr, family, err = socket.AddressAndFamily(sockaddr) if err != nil { return err } - if err = s.checkFamily(family, true /* exact */); err != nil { - return err + if !s.checkFamily(family, true /* exact */) { + return syserr.ErrAddressFamilyNotSupported } addr = s.mapFamily(addr, family) @@ -688,7 +706,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // Listen implements the linux syscall listen(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error { return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog)) } @@ -779,7 +797,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) { // Shutdown implements the linux syscall shutdown(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error { f, err := ConvertShutdown(how) if err != nil { return err @@ -860,7 +878,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, _ linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1345,6 +1363,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) return &v, nil + case linux.IPV6_RECVPKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv6ReceivePacketInfo())) + return &v, nil + case linux.IP6T_ORIGINAL_DST: if outLen < sockAddrInet6Size { return nil, syserr.ErrInvalidArgument @@ -1368,11 +1394,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true) + info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, true) if err != nil { return nil, err } @@ -1388,11 +1414,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen) + entries, err := netfilter.GetEntries6(t, stk.(*Stack).Stack, outPtr, outLen) if err != nil { return nil, err } @@ -1408,8 +1434,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber) @@ -1425,7 +1451,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { +func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, _ int) (marshal.Marshallable, *syserr.Error) { if _, ok := ep.(tcpip.Endpoint); !ok { log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) return nil, syserr.ErrUnknownProtocolOption @@ -1565,11 +1591,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false) + info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, false) if err != nil { return nil, err } @@ -1585,11 +1611,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } - entries, err := netfilter.GetEntries4(t, stack.(*Stack).Stack, outPtr, outLen) + entries, err := netfilter.GetEntries4(t, stk.(*Stack).Stack, outPtr, outLen) if err != nil { return nil, err } @@ -1605,8 +1631,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return nil, syserr.ErrNoDevice } ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber) @@ -2046,7 +2072,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { return syserr.ErrInvalidEndpointState - } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + } else if isUDPSocket(skType, skProto) && transport.DatagramEndpointState(ep.State()) != transport.DatagramEndpointStateInitial { return syserr.ErrInvalidEndpointState } @@ -2101,6 +2127,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) return nil + case linux.IPV6_RECVPKTINFO: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(hostarch.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetIPv6ReceivePacketInfo(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2143,12 +2178,12 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true) + return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, true) case linux.IP6T_SO_SET_ADD_COUNTERS: log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported") @@ -2386,12 +2421,12 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.ErrProtocolNotAvailable } - stack := inet.StackFromContext(t) - if stack == nil { + stk := inet.StackFromContext(t) + if stk == nil { return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false) + return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, false) case linux.IPT_SO_SET_ADD_COUNTERS: log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported") @@ -2490,7 +2525,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, linux.IPV6_RECVPATHMTU, - linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, linux.IPV6_RTHDR, linux.IPV6_RTHDRDSTOPTS, @@ -2559,7 +2593,7 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2571,7 +2605,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2716,6 +2750,8 @@ func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.Contr TClass: readCM.TClass, HasIPPacketInfo: readCM.HasIPPacketInfo, PacketInfo: readCM.PacketInfo, + HasIPv6PacketInfo: readCM.HasIPv6PacketInfo, + IPv6PacketInfo: readCM.IPv6PacketInfo, OriginalDstAddress: readCM.OriginalDstAddress, SockErr: readCM.SockErr, }, @@ -2730,7 +2766,7 @@ func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) { // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled. if !s.sockOptTimestamp { s.timestampValid = true - s.timestampNS = cm.Timestamp + s.timestamp = cm.Timestamp } } @@ -2789,7 +2825,7 @@ func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, _ uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { if flags&linux.MSG_ERRQUEUE != 0 { return s.recvErr(t, dst) } @@ -2873,8 +2909,8 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b if err != nil { return 0, err } - if err := s.checkFamily(family, false /* exact */); err != nil { - return 0, err + if !s.checkFamily(family, false /* exact */) { + return 0, syserr.ErrInvalidArgument } addrBuf = s.mapFamily(addrBuf, family) @@ -2951,10 +2987,10 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { - return 0, syserror.ENOENT + return 0, linuxerr.ENOENT } - tv := linux.NsecToTimeval(s.timestampNS) + tv := linux.NsecToTimeval(s.timestamp.UnixNano()) _, err := tv.CopyOut(t, args[2].Pointer()) return 0, err @@ -3061,7 +3097,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc } // interfaceIoctl implements interface requests. -func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error { +func interfaceIoctl(ctx context.Context, _ usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error { var ( iface inet.Interface index int32 @@ -3069,8 +3105,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe ) // Find the relevant device. - stack := inet.StackFromContext(ctx) - if stack == nil { + stk := inet.StackFromContext(ctx) + if stk == nil { return syserr.ErrNoDevice } @@ -3080,7 +3116,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(hostarch.ByteOrder.Uint32(ifr.Data[:4])) - if iface, ok := stack.Interfaces()[index]; ok { + if iface, ok := stk.Interfaces()[index]; ok { ifr.SetName(iface.Name) return nil } @@ -3088,7 +3124,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // Find the relevant device. - for index, iface = range stack.Interfaces() { + for index, iface = range stk.Interfaces() { if iface.Name == ifr.Name() { found = true break @@ -3121,7 +3157,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } case linux.SIOCGIFFLAGS: - f, err := interfaceStatusFlags(stack, iface.Name) + f, err := interfaceStatusFlags(stk, iface.Name) if err != nil { return err } @@ -3131,7 +3167,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe case linux.SIOCGIFADDR: // Copy the IPv4 address out. - for _, addr := range stack.InterfaceAddrs()[index] { + for _, addr := range stk.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. if addr.Family != linux.AF_INET { continue @@ -3167,7 +3203,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe case linux.SIOCGIFNETMASK: // Gets the network mask of a device. - for _, addr := range stack.InterfaceAddrs()[index] { + for _, addr := range stk.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. if addr.Family != linux.AF_INET { continue @@ -3199,24 +3235,24 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error { +func ifconfIoctl(ctx context.Context, t *kernel.Task, _ usermem.IO, ifc *linux.IFConf) error { // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. - stack := inet.StackFromContext(ctx) - if stack == nil { + stk := inet.StackFromContext(ctx) + if stk == nil { return syserr.ErrNoDevice.ToError() } if ifc.Ptr == 0 { - ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq) + ifc.Len = int32(len(stk.Interfaces())) * int32(linux.SizeOfIFReq) return nil } max := ifc.Len ifc.Len = 0 - for key, ifaceAddrs := range stack.InterfaceAddrs() { - iface := stack.Interfaces()[key] + for key, ifaceAddrs := range stk.InterfaceAddrs() { + iface := stk.Interfaces()[key] for _, ifaceAddr := range ifaceAddrs { // Don't write past the end of the buffer. if ifc.Len+int32(linux.SizeOfIFReq) > max { @@ -3332,10 +3368,10 @@ func (s *socketOpsCommon) State() uint32 { } case isUDPSocket(s.skType, s.protocol): // UDP socket. - switch udp.EndpointState(s.Endpoint.State()) { - case udp.StateInitial, udp.StateBound, udp.StateClosed: + switch transport.DatagramEndpointState(s.Endpoint.State()) { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateBound, transport.DatagramEndpointStateClosed: return linux.TCP_CLOSE - case udp.StateConnected: + case transport.DatagramEndpointStateConnected: return linux.TCP_ESTABLISHED default: return 0 diff --git a/pkg/sentry/socket/netstack/netstack_state.go b/pkg/sentry/socket/netstack/netstack_state.go new file mode 100644 index 000000000..591e00d42 --- /dev/null +++ b/pkg/sentry/socket/netstack/netstack_state.go @@ -0,0 +1,31 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netstack + +import ( + "time" +) + +func (s *socketOpsCommon) saveTimestamp() int64 { + s.readMu.Lock() + defer s.readMu.Unlock() + return s.timestamp.UnixNano() +} + +func (s *socketOpsCommon) loadTimestamp(nsec int64) { + s.readMu.Lock() + defer s.readMu.Unlock() + s.timestamp = time.Unix(0, nsec) +} diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index edc160b1b..3cdf29b80 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -113,7 +112,7 @@ func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs. } n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false) if err == syserr.ErrWouldBlock { - return int64(n), syserror.ErrWouldBlock + return int64(n), linuxerr.ErrWouldBlock } if err != nil { return 0, err.ToError() @@ -132,14 +131,14 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs r := src.Reader(ctx) n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if _, ok := err.(*tcpip.ErrWouldBlock); ok { - return 0, syserror.ErrWouldBlock + return 0, linuxerr.ErrWouldBlock } if err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } if n < src.NumBytes() { - return n, syserror.ErrWouldBlock + return n, linuxerr.ErrWouldBlock } return n, nil diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 208ab9909..ea199f223 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // Attach address to interface. nicID := tcpip.NICID(idx) - if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + if err := s.Stack.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { return syserr.TranslateNetstackError(err).ToError() } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 658e90bb9..d4b80a39d 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "sync/atomic" + "time" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" @@ -51,8 +52,19 @@ type ControlMessages struct { func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo { var p linux.ControlMessageIPPacketInfo p.NIC = int32(packetInfo.NIC) - copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) - copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + copy(p.LocalAddr[:], packetInfo.LocalAddr) + copy(p.DestinationAddr[:], packetInfo.DestinationAddr) + return p +} + +// ipv6PacketInfoToLinux converts IPv6PacketInfo from tcpip format to Linux +// format. +func ipv6PacketInfoToLinux(packetInfo tcpip.IPv6PacketInfo) linux.ControlMessageIPv6PacketInfo { + var p linux.ControlMessageIPv6PacketInfo + if n := copy(p.Addr[:], packetInfo.Addr); n != len(p.Addr) { + panic(fmt.Sprintf("got copy(%x, %x) = %d, want = %d", p.Addr, packetInfo.Addr, n, len(p.Addr))) + } + p.NIC = uint32(packetInfo.NIC) return p } @@ -114,7 +126,7 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa if cmgs.HasOriginalDstAddress { orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress) } - return IPControlMessages{ + cm := IPControlMessages{ HasTimestamp: cmgs.HasTimestamp, Timestamp: cmgs.Timestamp, HasInq: cmgs.HasInq, @@ -125,9 +137,16 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa TClass: cmgs.TClass, HasIPPacketInfo: cmgs.HasIPPacketInfo, PacketInfo: packetInfoToLinux(cmgs.PacketInfo), + HasIPv6PacketInfo: cmgs.HasIPv6PacketInfo, OriginalDstAddress: orgDstAddr, SockErr: sockErrCmsgToLinux(cmgs.SockErr), } + + if cm.HasIPv6PacketInfo { + cm.IPv6PacketInfo = ipv6PacketInfoToLinux(cmgs.IPv6PacketInfo) + } + + return cm } // IPControlMessages contains socket control messages for IP sockets. @@ -138,9 +157,9 @@ type IPControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packet used to create - // the read data was received. - Timestamp int64 + // Timestamp is the time that the last packet used to create the read data + // was received. + Timestamp time.Time `state:".(int64)"` // HasInq indicates whether Inq is valid/set. HasInq bool @@ -166,6 +185,12 @@ type IPControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo linux.ControlMessageIPPacketInfo + // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set. + HasIPv6PacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + IPv6PacketInfo linux.ControlMessageIPv6PacketInfo + // OriginalDestinationAddress holds the original destination address // and port of the incoming packet. OriginalDstAddress linux.SockAddr @@ -743,6 +768,8 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } a.UnmarshalUnsafe(addr[:sockAddrLinkSize]) + // TODO(https://gvisor.dev/issue/6530): Do not assume all interfaces have + // an ethernet address. if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } @@ -750,6 +777,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + Port: Ntohs(a.Protocol), }, family, nil case linux.AF_UNSPEC: diff --git a/pkg/sentry/socket/socket_state.go b/pkg/sentry/socket/socket_state.go new file mode 100644 index 000000000..32e12b238 --- /dev/null +++ b/pkg/sentry/socket/socket_state.go @@ -0,0 +1,27 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package socket + +import ( + "time" +) + +func (i *IPControlMessages) saveTimestamp() int64 { + return i.Timestamp.UnixNano() +} + +func (i *IPControlMessages) loadTimestamp(nsec int64) { + i.Timestamp = time.Unix(0, nsec) +} diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 5c3cdef6a..7b546c04d 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -62,7 +62,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/syserr", - "//pkg/syserror", "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 33f9aeb06..b3f0cf563 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -129,9 +129,9 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv stype: stype, } + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) return ep } @@ -406,14 +406,15 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error { // Accept accepts a new connection. func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) { e.Lock() - defer e.Unlock() if !e.Listening() { + e.Unlock() return nil, syserr.ErrInvalidEndpointState } select { case ne := <-e.acceptedChan: + e.Unlock() if peerAddr != nil { ne.Lock() c := ne.connected @@ -429,6 +430,7 @@ func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *s return ne, nil default: + e.Unlock() // Nothing left. return nil, syserr.ErrWouldBlock } @@ -517,3 +519,6 @@ func (e *connectionedEndpoint) OnSetSendBufferSize(v int64) (newSz int64) { } return v } + +// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters. +func (e *connectionedEndpoint) WakeupWriters() {} diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 61338728a..61311718e 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -44,9 +44,9 @@ func NewConnectionless(ctx context.Context) Endpoint { q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: defaultBufferSize} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) return ep } @@ -227,3 +227,6 @@ func (e *connectionlessEndpoint) OnSetSendBufferSize(v int64) (newSz int64) { } return v } + +// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters. +func (e *connectionlessEndpoint) WakeupWriters() {} diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index e4de44498..188ad3bd9 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -59,12 +59,14 @@ func (q *queue) Close() { // q.WriterQueue.Notify(waiter.WritableEvents) func (q *queue) Reset(ctx context.Context) { q.mu.Lock() - for cur := q.dataList.Front(); cur != nil; cur = cur.Next() { - cur.Release(ctx) - } + dataList := q.dataList q.dataList.Reset() q.used = 0 q.mu.Unlock() + + for cur := dataList.Front(); cur != nil; cur = cur.Next() { + cur.Release(ctx) + } } // DecRef implements RefCounter.DecRef. @@ -133,7 +135,7 @@ func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, f free := q.limit - q.used if l > free && truncate { - if free == 0 { + if free <= 0 { // Message can't fit right now. q.mu.Unlock() return 0, false, syserr.ErrWouldBlock diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 8ccdadae9..e9e482017 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -38,7 +38,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -494,7 +493,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } n, err := src.CopyInTo(t, &w) - if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + if err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { return int(n), syserr.FromError(err) } @@ -514,13 +513,13 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b n, err = src.CopyInTo(t, &w) total += n - if err != syserror.ErrWouldBlock { + if err != linuxerr.ErrWouldBlock { break } if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { - err = syserror.ErrWouldBlock + err = linuxerr.ErrWouldBlock } break } @@ -648,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } var total int64 - if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait { + if n, err := doRead(); err != linuxerr.ErrWouldBlock || dontWait { var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { @@ -683,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - if n, err := doRead(); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != linuxerr.ErrWouldBlock { var from linux.SockAddr var fromLen uint32 if r.From != nil { |