diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 8 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 58 | ||||
-rw-r--r-- | pkg/tcpip/header/icmpv4.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/socketops.go | 69 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 54 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 79 |
7 files changed, 268 insertions, 21 deletions
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 2b34ef190..5b868216d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -331,12 +331,12 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: @@ -377,14 +377,14 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 case linux.IP_PKTINFO: optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index a8ab6b385..460c95b9f 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1405,6 +1405,13 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) return &v, nil + case linux.IPV6_RECVERR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + return &v, nil case linux.IPV6_RECVORIGDSTADDR: if outLen < sizeOfInt32 { @@ -1579,6 +1586,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) return &v, nil + case linux.IP_RECVERR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + return &v, nil + case linux.IP_PKTINFO: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -2129,6 +2144,16 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name ep.SocketOptions().SetReceiveTClass(v != 0) return nil + case linux.IPV6_RECVERR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + ep.SocketOptions().SetRecvError(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2317,6 +2342,17 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in ep.SocketOptions().SetReceiveTOS(v != 0) return nil + case linux.IP_RECVERR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + ep.SocketOptions().SetRecvError(v != 0) + return nil + case linux.IP_PKTINFO: if len(optVal) == 0 { return nil @@ -2386,7 +2422,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_NODEFRAG, linux.IP_OPTIONS, linux.IP_PASSSEC, - linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, linux.IP_RECVTTL, @@ -2462,7 +2497,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_MULTICAST_IF, linux.IPV6_MULTICAST_LOOP, linux.IPV6_RECVDSTOPTS, - linux.IPV6_RECVERR, linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, @@ -2496,7 +2530,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { linux.IP_PKTINFO, linux.IP_PKTOPTIONS, linux.IP_MTU_DISCOVER, - linux.IP_RECVERR, linux.IP_RECVTTL, linux.IP_RECVTOS, linux.IP_MTU, @@ -2798,6 +2831,23 @@ func (s *socketOpsCommon) updateTimestamp() { } } +// dequeueErr is analogous to net/core/skbuff.c:sock_dequeue_err_skb(). +func (s *socketOpsCommon) dequeueErr() *tcpip.SockError { + so := s.Endpoint.SocketOptions() + err := so.DequeueErr() + if err == nil { + return nil + } + + // Update socket error to reflect ICMP errors in queue. + if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() { + so.SetLastError(nextErr.Err) + } else if err.ErrOrigin.IsICMPErr() { + so.SetLastError(nil) + } + return err +} + // addrFamilyFromNetProto returns the address family identifier for the given // network protocol. func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int { @@ -2814,7 +2864,7 @@ func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int { // recvErr handles MSG_ERRQUEUE for recvmsg(2). // This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error(). func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { - sockErr := s.Endpoint.SocketOptions().DequeueErr() + sockErr := s.dequeueErr() if sockErr == nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 2f13dea6a..1be90d7d5 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -213,3 +214,16 @@ func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { return xsum } + +// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when +// a packet having a `net` header causing an ICMP error. +func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin { + switch net { + case IPv4ProtocolNumber: + return tcpip.SockExtErrorOriginICMP + case IPv6ProtocolNumber: + return tcpip.SockExtErrorOriginICMP6 + default: + panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net)) + } +} diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index eb63d735f..095d1734a 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -42,6 +42,9 @@ type SocketOptionsHandler interface { // LastError is invoked when SO_ERROR is read for an endpoint. LastError() *Error + + // UpdateLastError updates the endpoint specific last error field. + UpdateLastError(err *Error) } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -70,6 +73,9 @@ func (*DefaultSocketOptionsHandler) LastError() *Error { return nil } +// UpdateLastError implements SocketOptionsHandler.UpdateLastError. +func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {} + // SocketOptions contains all the variables which store values for SOL_SOCKET, // SOL_IP, SOL_IPV6 and SOL_TCP level options. // @@ -145,6 +151,10 @@ type SocketOptions struct { // the incoming packet should be returned as an ancillary message. receiveOriginalDstAddress uint32 + // recvErrEnabled determines whether extended reliable error message passing + // is enabled. + recvErrEnabled uint32 + // errQueue is the per-socket error queue. It is protected by errQueueMu. errQueueMu sync.Mutex `state:"nosave"` errQueue sockErrorList @@ -171,6 +181,11 @@ func storeAtomicBool(addr *uint32, v bool) { atomic.StoreUint32(addr, val) } +// SetLastError sets the last error for a socket. +func (so *SocketOptions) SetLastError(err *Error) { + so.handler.UpdateLastError(err) +} + // GetBroadcast gets value for SO_BROADCAST option. func (so *SocketOptions) GetBroadcast() bool { return atomic.LoadUint32(&so.broadcastEnabled) != 0 @@ -338,6 +353,19 @@ func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) { storeAtomicBool(&so.receiveOriginalDstAddress, v) } +// GetRecvError gets value for IP*_RECVERR option. +func (so *SocketOptions) GetRecvError() bool { + return atomic.LoadUint32(&so.recvErrEnabled) != 0 +} + +// SetRecvError sets value for IP*_RECVERR option. +func (so *SocketOptions) SetRecvError(v bool) { + storeAtomicBool(&so.recvErrEnabled, v) + if !v { + so.pruneErrQueue() + } +} + // GetLastError gets value for SO_ERROR option. func (so *SocketOptions) GetLastError() *Error { return so.handler.LastError() @@ -384,6 +412,11 @@ const ( SockExtErrorOriginICMP6 ) +// IsICMPErr indicates if the error originated from an ICMP error. +func (origin SockErrOrigin) IsICMPErr() bool { + return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6 +} + // SockError represents a queue entry in the per-socket error queue. // // +stateify savable @@ -411,6 +444,13 @@ type SockError struct { NetProto NetworkProtocolNumber } +// pruneErrQueue resets the queue. +func (so *SocketOptions) pruneErrQueue() { + so.errQueueMu.Lock() + so.errQueue.Reset() + so.errQueueMu.Unlock() +} + // DequeueErr dequeues a socket extended error from the error queue and returns // it. Returns nil if queue is empty. func (so *SocketOptions) DequeueErr() *SockError { @@ -423,3 +463,32 @@ func (so *SocketOptions) DequeueErr() *SockError { } return err } + +// PeekErr returns the error in the front of the error queue. Returns nil if +// the error queue is empty. +func (so *SocketOptions) PeekErr() *SockError { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + return so.errQueue.Front() +} + +// QueueErr inserts the error at the back of the error queue. +// +// Preconditions: so.GetRecvError() == true. +func (so *SocketOptions) QueueErr(err *SockError) { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + so.errQueue.PushBack(err) +} + +// QueueLocalErr queues a local error onto the local queue. +func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { + so.QueueErr(&SockError{ + Err: err, + ErrOrigin: SockExtErrorOriginLocal, + ErrInfo: info, + Payload: payload, + Dst: dst, + NetProto: net, + }) +} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 9faab4b9e..e5e247342 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -366,6 +366,13 @@ func (ep *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (ep *endpoint) UpdateLastError(err *tcpip.Error) { + ep.lastErrorMu.Lock() + ep.lastError = err + ep.lastErrorMu.Unlock() +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { return tcpip.ErrNotSupported diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index bb0795f78..2128206d7 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1303,6 +1303,15 @@ func (e *endpoint) LastError() *tcpip.Error { return e.lastErrorLocked() } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.LockUser() + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + e.UnlockUser() +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -2708,6 +2717,41 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } +func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + // Linux passes the payload with the TCP header. We don't know if the TCP + // header even exists, it may not for fragmented packets. + Payload: pkt.Data.ToView(), + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.RemoteAddress, + Port: id.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: id.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.notifyProtocolGoroutine(notifyError) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { @@ -2722,16 +2766,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.notifyProtocolGoroutine(notifyMTUChanged) case stack.ControlNoRoute: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNoRoute - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNoRoute, id, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) case stack.ControlNetworkUnreachable: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNetworkUnreachable - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNetworkUnreachable, id, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 8e16c8435..d919fa011 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -226,6 +226,13 @@ func (e *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() +} + // Abort implements stack.TransportEndpoint.Abort. func (e *endpoint) Abort() { e.Close() @@ -511,6 +518,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. + so := e.SocketOptions() + if so.GetRecvError() { + so.QueueLocalErr( + tcpip.ErrMessageTooLong, + route.NetProto, + header.UDPMaximumPacketSize, + tcpip.FullAddress{ + NIC: route.NICID(), + Addr: route.RemoteAddress, + Port: dstPort, + }, + v, + ) + } return 0, nil, tcpip.ErrMessageTooLong } @@ -1338,15 +1359,63 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } +func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + // Linux passes the payload without the UDP header. + var payload []byte + udp := header.UDP(pkt.Data.ToView()) + if len(udp) >= header.UDPMinimumSize { + payload = udp.Payload() + } + + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + Payload: payload, + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.RemoteAddress, + Port: id.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: id.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.waiterQueue.Notify(waiter.EventErr) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { if e.EndpointState() == StateConnected { - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrConnectionRefused - e.lastErrorMu.Unlock() - - e.waiterQueue.Notify(waiter.EventErr) + var errType byte + var errCode byte + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + errType = byte(header.ICMPv4DstUnreachable) + errCode = byte(header.ICMPv4PortUnreachable) + case header.IPv6ProtocolNumber: + errType = byte(header.ICMPv6DstUnreachable) + errCode = byte(header.ICMPv6PortUnreachable) + default: + panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) + } + e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt) return } } |