diff options
Diffstat (limited to 'pkg/sentry/socket/epsocket')
-rw-r--r-- | pkg/sentry/socket/epsocket/BUILD | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 136 | ||||
-rw-r--r-- | pkg/sentry/socket/epsocket/stack.go | 57 |
3 files changed, 146 insertions, 49 deletions
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD index 1f014f399..e927821e1 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/epsocket/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", + "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", "//pkg/syserr", @@ -38,6 +39,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index e57aed927..635042263 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -43,6 +43,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" @@ -290,18 +291,22 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } -// GetAddress reads an sockaddr struct from the given address and converts it -// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6 -// addresses. -func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) { +// AddressAndFamily reads an sockaddr struct from the given address and +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and +// AF_INET6 addresses. +// +// strict indicates whether addresses with the AF_UNSPEC family are accepted of not. +// +// AddressAndFamily returns an address, its family. +func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument } family := usermem.ByteOrder.Uint16(addr) if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { - return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported + return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported } // Get the rest of the fields based on the address family. @@ -309,7 +314,7 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse case linux.AF_UNIX: path := addr[2:] if len(path) > linux.UnixPathMax { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } // Drop the terminating NUL (if one exists) and everything after // it for filesystem (non-abstract) addresses. @@ -320,12 +325,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse } return tcpip.FullAddress{ Addr: tcpip.Address(path), - }, nil + }, family, nil case linux.AF_INET: var a linux.SockAddrInet if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) @@ -333,12 +338,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse Addr: bytesToIPAddress(a.Addr[:]), Port: ntohs(a.Port), } - return out, nil + return out, family, nil case linux.AF_INET6: var a linux.SockAddrInet6 if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, syserr.ErrInvalidArgument + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) @@ -349,13 +354,13 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse if isLinkLocal(out.Addr) { out.NIC = tcpip.NICID(a.Scope_id) } - return out, nil + return out, family, nil case linux.AF_UNSPEC: - return tcpip.FullAddress{}, nil + return tcpip.FullAddress{}, family, nil default: - return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported + return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported } } @@ -428,6 +433,11 @@ func (i *ioSequencePayload) Size() int { return int(i.src.NumBytes()) } +// DropFirst drops the first n bytes from underlying src. +func (i *ioSequencePayload) DropFirst(n int) { + i.src = i.src.DropFirst(int(n)) +} + // Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { f := &ioSequencePayload{ctx: ctx, src: src} @@ -476,11 +486,18 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr, false /* strict */) + addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */) if err != nil { return err } + if family == linux.AF_UNSPEC { + err := s.Endpoint.Disconnect() + if err == tcpip.ErrNotSupported { + return syserr.ErrAddressFamilyNotSupported + } + return syserr.TranslateNetstackError(err) + } // Always return right away in the non-blocking case. if !blocking { return syserr.TranslateNetstackError(s.Endpoint.Connect(addr)) @@ -509,7 +526,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr, true /* strict */) + addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */) if err != nil { return err } @@ -547,7 +564,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, wq, terr := s.Endpoint.Accept() if terr != nil { @@ -574,7 +591,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, ns.SetFlags(flags.Settable()) } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if peerRequested { // Get address of the peer and write it to peer slice. @@ -624,7 +641,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for epsocket.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -655,6 +672,33 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( return val, nil } + if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { + switch name { + case linux.IPT_SO_GET_INFO: + if outLen < linux.SizeOfIPTGetinfo { + return nil, syserr.ErrInvalidArgument + } + + info, err := netfilter.GetInfo(t, s.Endpoint, outPtr) + if err != nil { + return nil, err + } + return info, nil + + case linux.IPT_SO_GET_ENTRIES: + if outLen < linux.SizeOfIPTGetEntries { + return nil, syserr.ErrInvalidArgument + } + + entries, err := netfilter.GetEntries(t, s.Endpoint, outPtr, outLen) + if err != nil { + return nil, err + } + return entries, nil + + } + } + return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) } @@ -1028,7 +1072,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(linux.SockAddrInet).Addr, nil + return a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1658,7 +1702,7 @@ func isLinkLocal(addr tcpip.Address) bool { } // ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { switch family { case linux.AF_UNIX: var out linux.SockAddrUnix @@ -1674,15 +1718,15 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { // address length is the max. Abstract and empty paths always return // the full exact length. if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return out, uint32(2 + l) + return &out, uint32(2 + l) } - return out, uint32(3 + l) + return &out, uint32(3 + l) case linux.AF_INET: var out linux.SockAddrInet copy(out.Addr[:], addr.Addr) out.Family = linux.AF_INET out.Port = htons(addr.Port) - return out, uint32(binary.Size(out)) + return &out, uint32(binary.Size(out)) case linux.AF_INET6: var out linux.SockAddrInet6 if len(addr.Addr) == 4 { @@ -1698,7 +1742,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { if isLinkLocal(addr.Addr) { out.Scope_id = uint32(addr.NIC) } - return out, uint32(binary.Size(out)) + return &out, uint32(binary.Size(out)) default: return nil, 0 } @@ -1706,7 +1750,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) { // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -1718,7 +1762,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -1791,7 +1835,7 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. -func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() // Fast path for regular reads from stream (e.g., TCP) endpoints. Note @@ -1839,7 +1883,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if err == nil { s.updateTimestamp() } - var addr interface{} + var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) @@ -1914,7 +1958,7 @@ func (s *SocketOperations) updateTimestamp() { // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -1990,7 +2034,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, err := GetAddress(s.family, to, true /* strict */) + addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */) if err != nil { return 0, err } @@ -1998,28 +2042,22 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] addr = &addrBuf } - v := buffer.NewView(int(src.NumBytes())) - - // Copy all the data into the buffer. - if _, err := src.CopyIn(t, v); err != nil { - return 0, syserr.FromError(err) - } - opts := tcpip.WriteOptions{ To: addr, More: flags&linux.MSG_MORE != 0, EndOfRecord: flags&linux.MSG_EOR != 0, } - n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts) + v := &ioSequencePayload{t, src} + n, resCh, err := s.Endpoint.Write(v, opts) if resCh != nil { if err := t.Block(resCh); err != nil { return 0, syserr.FromError(err) } - n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) + n, _, err = s.Endpoint.Write(v, opts) } dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= uintptr(len(v)) || dontWait) { + if err == nil && (n >= int64(v.Size()) || dontWait) { // Complete write. return int(n), nil } @@ -2033,18 +2071,18 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] s.EventRegister(&e, waiter.EventOut) defer s.EventUnregister(&e) - v.TrimFront(int(n)) + v.DropFirst(int(n)) total := n for { - n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) - v.TrimFront(int(n)) + n, _, err = s.Endpoint.Write(v, opts) + v.DropFirst(int(n)) total += n if err != nil && err != tcpip.ErrWouldBlock && total == 0 { return 0, syserr.TranslateNetstackError(err) } - if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock { + if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock { return int(total), nil } @@ -2252,19 +2290,19 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe case syscall.SIOCGIFMAP: // Gets the hardware parameters of the device. - // TODO(b/71872867): Implement. + // TODO(gvisor.dev/issue/505): Implement. case syscall.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. - // TODO(b/71872867): Implement. + // TODO(gvisor.dev/issue/505): Implement. case syscall.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. - // TODO(b/71872867): Implement. + // TODO(gvisor.dev/issue/505): Implement. case syscall.SIOCGIFBRDADDR: // Gets the broadcast address of a device. - // TODO(b/71872867): Implement. + // TODO(gvisor.dev/issue/505): Implement. case syscall.SIOCGIFNETMASK: // Gets the network mask of a device. diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go index 8fe489c0e..7cf7ff735 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/epsocket/stack.go @@ -18,7 +18,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -143,3 +146,57 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { func (s *Stack) Statistics(stat interface{}, arg string) error { return syserr.ErrEndpointOperation.ToError() } + +// RouteTable implements inet.Stack.RouteTable. +func (s *Stack) RouteTable() []inet.Route { + var routeTable []inet.Route + + for _, rt := range s.Stack.GetRouteTable() { + var family uint8 + switch len(rt.Destination.ID()) { + case header.IPv4AddressSize: + family = linux.AF_INET + case header.IPv6AddressSize: + family = linux.AF_INET6 + default: + log.Warningf("Unknown network protocol in route %+v", rt) + continue + } + + routeTable = append(routeTable, inet.Route{ + Family: family, + DstLen: uint8(rt.Destination.Prefix()), // The CIDR prefix for the destination. + + // Always return unspecified protocol since we have no notion of + // protocol for routes. + Protocol: linux.RTPROT_UNSPEC, + // Set statically to LINK scope for now. + // + // TODO(gvisor.dev/issue/595): Set scope for routes. + Scope: linux.RT_SCOPE_LINK, + Type: linux.RTN_UNICAST, + + DstAddr: []byte(rt.Destination.ID()), + OutputInterface: int32(rt.NIC), + GatewayAddr: []byte(rt.Gateway), + }) + } + + return routeTable +} + +// IPTables returns the stack's iptables. +func (s *Stack) IPTables() (iptables.IPTables, error) { + return s.Stack.IPTables(), nil +} + +// FillDefaultIPTables sets the stack's iptables to the default tables, which +// allow and do not modify all traffic. +func (s *Stack) FillDefaultIPTables() { + netfilter.FillDefaultIPTables(s.Stack) +} + +// Resume implements inet.Stack.Resume. +func (s *Stack) Resume() { + s.Stack.Resume() +} |