diff options
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 149 | ||||
-rw-r--r-- | pkg/tcpip/socketops.go | 106 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/tcpip_state_autogen.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/raw_state_autogen.go | 51 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 67 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 229 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 122 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_state_autogen.go | 79 |
15 files changed, 388 insertions, 500 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e8a0103bf..1184acc7a 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -260,6 +260,10 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 + // LastError implements tcpip.Endpoint.LastError and // transport.Endpoint.LastError. LastError() *tcpip.Error @@ -723,11 +727,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { return nil } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { - v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return syserr.TranslateNetstackError(err) - } - if !v { + if !s.Endpoint.SocketOptions().GetV6Only() { return nil } } @@ -1226,8 +1226,13 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v := primitive.Int32(boolToInt32(ep.SocketOptions().GetAcceptConn())) - return &v, nil + // This option is only viable for TCP endpoints. + var v bool + if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) { + v = tcp.EndpointState(ep.State()) == tcp.StateListen + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1449,19 +1454,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + + family, skType, _ := s.Type() + if family != linux.AF_INET6 { + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only())) + return &v, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1493,13 +1503,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) + return &v, nil case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { @@ -1520,7 +1525,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1540,7 +1545,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1560,7 +1565,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1582,6 +1587,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name // getSockOptIP implements GetSockOpt when level is SOL_IP. func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1633,13 +1643,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop())) + return &v, nil case linux.IP_TOS: // Length handling for parity with Linux. @@ -1663,26 +1668,24 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) + return &v, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo())) + return &v, nil + + case linux.IP_HDRINCL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) + return &v, nil case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { @@ -2127,14 +2130,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + + family, skType, skProto := s.Type() + if family != linux.AF_INET6 { + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } + if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { + return syserr.ErrInvalidEndpointState + } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + return syserr.ErrInvalidEndpointState + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) + ep.SocketOptions().SetV6Only(v != 0) + return nil case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, @@ -2173,7 +2193,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + ep.SocketOptions().SetReceiveTClass(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2181,7 +2202,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return syserr.ErrProtocolNotAvailable } @@ -2256,6 +2277,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { // setSockOptIP implements SetSockOpt when level is SOL_IP. func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_MULTICAST_TTL: v, err := parseIntOrChar(optVal) @@ -2317,7 +2343,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) + ep.SocketOptions().SetMulticastLoop(v != 0) + return nil case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2353,7 +2380,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + ep.SocketOptions().SetReceiveTOS(v != 0) + return nil case linux.IP_PKTINFO: if len(optVal) == 0 { @@ -2363,7 +2391,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + ep.SocketOptions().SetReceivePacketInfo(v != 0) + return nil case linux.IP_HDRINCL: if len(optVal) == 0 { @@ -2373,7 +2402,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + ep.SocketOptions().SetHeaderIncluded(v != 0) + return nil case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { @@ -2515,7 +2545,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { switch name { case linux.IP_TOS, linux.IP_TTL, - linux.IP_HDRINCL, linux.IP_OPTIONS, linux.IP_ROUTER_ALERT, linux.IP_RECVOPTS, @@ -3384,6 +3413,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { return rv } +func isTCPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP) +} + +func isUDPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP) +} + +func isICMPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6) +} + // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. func (s *socketOpsCommon) State() uint32 { @@ -3393,7 +3434,7 @@ func (s *socketOpsCommon) State() uint32 { } switch { - case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: + case isTCPSocket(s.skType, s.protocol): // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -3422,7 +3463,7 @@ func (s *socketOpsCommon) State() uint32 { // Internal or unknown state. return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + case isUDPSocket(s.skType, s.protocol): // UDP socket. switch udp.EndpointState(s.Endpoint.State()) { case udp.StateInitial, udp.StateBound, udp.StateClosed: @@ -3432,7 +3473,7 @@ func (s *socketOpsCommon) State() uint32 { default: return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + case isICMPSocket(s.skType, s.protocol): // TODO(b/112063468): Export states for ICMP sockets. case s.skType == linux.SOCK_RAW: // TODO(b/112063468): Export states for raw sockets. diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 1b1188ee5..cced4d8fc 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -19,10 +19,8 @@ import ( ) // SocketOptionsHandler holds methods that help define endpoint specific -// behavior for socket options. -// These must be implemented by endpoints to: -// - Get notified when socket level options are set. -// - Provide endpoint specific socket options. +// behavior for socket level socket options. These must be implemented by +// endpoints to get notified when socket level options are set. type SocketOptionsHandler interface { // OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint. OnReuseAddressSet(v bool) @@ -32,10 +30,6 @@ type SocketOptionsHandler interface { // OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint. OnKeepAliveSet(v bool) - - // IsListening is invoked to fetch SO_ACCEPTCONN option value for an - // endpoint. It is used to indicate if the socket is a listening socket. - IsListening() bool } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -53,11 +47,8 @@ func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {} // OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet. func (*DefaultSocketOptionsHandler) OnKeepAliveSet(bool) {} -// IsListening implements SocketOptionsHandler.IsListening. -func (*DefaultSocketOptionsHandler) IsListening() bool { return false } - -// SocketOptions contains all the variables which store values for SOL_SOCKET -// level options. +// SocketOptions contains all the variables which store values for SOL_SOCKET, +// SOL_IP and SOL_IPV6 level options. // // +stateify savable type SocketOptions struct { @@ -88,6 +79,31 @@ type SocketOptions struct { // keepAliveEnabled determines whether TCP keepalive is enabled for this // socket. keepAliveEnabled uint32 + + // multicastLoopEnabled determines whether multicast packets sent over a + // non-loopback interface will be looped back. Analogous to inet->mc_loop. + multicastLoopEnabled uint32 + + // receiveTOSEnabled is used to specify if the TOS ancillary message is + // passed with incoming packets. + receiveTOSEnabled uint32 + + // receiveTClassEnabled is used to specify if the IPV6_TCLASS ancillary + // message is passed with incoming packets. + receiveTClassEnabled uint32 + + // receivePacketInfoEnabled is used to specify if more inforamtion is + // provided with incoming packets such as interface index and address. + receivePacketInfoEnabled uint32 + + // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets + // being written have an IP header and the endpoint should not attach an IP + // header. + hdrIncludedEnabled uint32 + + // v6OnlyEnabled is used to determine whether an IPv6 socket is to be + // restricted to sending and receiving IPv6 packets only. + v6OnlyEnabled uint32 } // InitHandler initializes the handler. This must be called before using the @@ -167,8 +183,64 @@ func (so *SocketOptions) SetKeepAlive(v bool) { so.handler.OnKeepAliveSet(v) } -// GetAcceptConn gets value for SO_ACCEPTCONN option. -func (so *SocketOptions) GetAcceptConn() bool { - // This option is completely endpoint dependent and unsettable. - return so.handler.IsListening() +// GetMulticastLoop gets value for IP_MULTICAST_LOOP option. +func (so *SocketOptions) GetMulticastLoop() bool { + return atomic.LoadUint32(&so.multicastLoopEnabled) != 0 +} + +// SetMulticastLoop sets value for IP_MULTICAST_LOOP option. +func (so *SocketOptions) SetMulticastLoop(v bool) { + storeAtomicBool(&so.multicastLoopEnabled, v) +} + +// GetReceiveTOS gets value for IP_RECVTOS option. +func (so *SocketOptions) GetReceiveTOS() bool { + return atomic.LoadUint32(&so.receiveTOSEnabled) != 0 +} + +// SetReceiveTOS sets value for IP_RECVTOS option. +func (so *SocketOptions) SetReceiveTOS(v bool) { + storeAtomicBool(&so.receiveTOSEnabled, v) +} + +// GetReceiveTClass gets value for IPV6_RECVTCLASS option. +func (so *SocketOptions) GetReceiveTClass() bool { + return atomic.LoadUint32(&so.receiveTClassEnabled) != 0 +} + +// SetReceiveTClass sets value for IPV6_RECVTCLASS option. +func (so *SocketOptions) SetReceiveTClass(v bool) { + storeAtomicBool(&so.receiveTClassEnabled, v) +} + +// GetReceivePacketInfo gets value for IP_PKTINFO option. +func (so *SocketOptions) GetReceivePacketInfo() bool { + return atomic.LoadUint32(&so.receivePacketInfoEnabled) != 0 +} + +// SetReceivePacketInfo sets value for IP_PKTINFO option. +func (so *SocketOptions) SetReceivePacketInfo(v bool) { + storeAtomicBool(&so.receivePacketInfoEnabled, v) +} + +// GetHeaderIncluded gets value for IP_HDRINCL option. +func (so *SocketOptions) GetHeaderIncluded() bool { + return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0 +} + +// SetHeaderIncluded sets value for IP_HDRINCL option. +func (so *SocketOptions) SetHeaderIncluded(v bool) { + storeAtomicBool(&so.hdrIncludedEnabled, v) +} + +// GetV6Only gets value for IPV6_V6ONLY option. +func (so *SocketOptions) GetV6Only() bool { + return atomic.LoadUint32(&so.v6OnlyEnabled) != 0 +} + +// SetV6Only sets value for IPV6_V6ONLY option. +// +// Preconditions: the backing TCP or UDP endpoint must be in initial state. +func (so *SocketOptions) SetV6Only(v bool) { + storeAtomicBool(&so.v6OnlyEnabled, v) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 6ed00e74f..2eb6e76af 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -718,37 +718,8 @@ const ( // TCP, it determines if the Nagle algorithm is on or off. DelayOption - // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to - // specify whether multicast packets sent over a non-loopback interface - // will be looped back. - MulticastLoopOption - // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool. QuickAckOption - - // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to - // specify if the IPV6_TCLASS ancillary message is passed with incoming - // packets. - ReceiveTClassOption - - // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify - // if the TOS ancillary message is passed with incoming packets. - ReceiveTOSOption - - // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to - // specify if more inforamtion is provided with incoming packets such as - // interface index and address. - ReceiveIPPacketInfoOption - - // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify - // whether an IPv6 socket is to be restricted to sending and receiving - // IPv6 packets only. - V6OnlyOption - - // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw - // endpoint that all packets being written have an IP header and the - // endpoint should not attach an IP header. - IPHdrIncludedOption ) // SockOptInt represents socket options which values have the int type. diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go index f40b40b84..c57c5f61c 100644 --- a/pkg/tcpip/tcpip_state_autogen.go +++ b/pkg/tcpip/tcpip_state_autogen.go @@ -19,6 +19,12 @@ func (so *SocketOptions) StateFields() []string { "reuseAddressEnabled", "reusePortEnabled", "keepAliveEnabled", + "multicastLoopEnabled", + "receiveTOSEnabled", + "receiveTClassEnabled", + "receivePacketInfoEnabled", + "hdrIncludedEnabled", + "v6OnlyEnabled", } } @@ -33,6 +39,12 @@ func (so *SocketOptions) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(4, &so.reuseAddressEnabled) stateSinkObject.Save(5, &so.reusePortEnabled) stateSinkObject.Save(6, &so.keepAliveEnabled) + stateSinkObject.Save(7, &so.multicastLoopEnabled) + stateSinkObject.Save(8, &so.receiveTOSEnabled) + stateSinkObject.Save(9, &so.receiveTClassEnabled) + stateSinkObject.Save(10, &so.receivePacketInfoEnabled) + stateSinkObject.Save(11, &so.hdrIncludedEnabled) + stateSinkObject.Save(12, &so.v6OnlyEnabled) } func (so *SocketOptions) afterLoad() {} @@ -45,6 +57,12 @@ func (so *SocketOptions) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(4, &so.reuseAddressEnabled) stateSourceObject.Load(5, &so.reusePortEnabled) stateSourceObject.Load(6, &so.keepAliveEnabled) + stateSourceObject.Load(7, &so.multicastLoopEnabled) + stateSourceObject.Load(8, &so.receiveTOSEnabled) + stateSourceObject.Load(9, &so.receiveTClassEnabled) + stateSourceObject.Load(10, &so.receivePacketInfoEnabled) + stateSourceObject.Load(11, &so.hdrIncludedEnabled) + stateSourceObject.Load(12, &so.v6OnlyEnabled) } func (f *FullAddress) StateTypeName() string { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 0a714498d..5eacd8d24 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -148,6 +148,7 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index e2c7a0d62..da402bad9 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -549,8 +549,10 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { return &ep.stats } +// SetOwner implements tcpip.Endpoint.SetOwner. func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} +// SocketOptions implements tcpip.Endpoint.SocketOptions. func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 2b1022995..0478900c3 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -65,7 +65,6 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool - hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -116,9 +115,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, - hdrIncluded: !associated, } e.ops.InitHandler(e) + e.ops.SetHeaderIncluded(!associated) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -271,7 +270,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if e.hdrIncluded { + if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -361,7 +360,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - if e.hdrIncluded { + if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), }) @@ -538,13 +537,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.IPHdrIncludedOption: - e.mu.Lock() - e.hdrIncluded = v - e.mu.Unlock() - return nil - } return tcpip.ErrUnknownProtocolOption } @@ -608,16 +600,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.IPHdrIncludedOption: - e.mu.Lock() - v := e.hdrIncluded - e.mu.Unlock() - return v, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } + return false, tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go index 4874e4636..6a216ccc2 100644 --- a/pkg/tcpip/transport/raw/raw_state_autogen.go +++ b/pkg/tcpip/transport/raw/raw_state_autogen.go @@ -50,7 +50,6 @@ func (e *endpoint) StateFields() []string { "DefaultSocketOptionsHandler", "waiterQueue", "associated", - "hdrIncluded", "rcvList", "rcvBufSize", "rcvBufSizeMax", @@ -69,23 +68,22 @@ func (e *endpoint) StateFields() []string { func (e *endpoint) StateSave(stateSinkObject state.Sink) { e.beforeSave() var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax() - stateSinkObject.SaveValue(7, rcvBufSizeMaxValue) + stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) stateSinkObject.Save(0, &e.TransportEndpointInfo) stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) stateSinkObject.Save(2, &e.waiterQueue) stateSinkObject.Save(3, &e.associated) - stateSinkObject.Save(4, &e.hdrIncluded) - stateSinkObject.Save(5, &e.rcvList) - stateSinkObject.Save(6, &e.rcvBufSize) - stateSinkObject.Save(8, &e.rcvClosed) - stateSinkObject.Save(9, &e.sndBufSize) - stateSinkObject.Save(10, &e.sndBufSizeMax) - stateSinkObject.Save(11, &e.closed) - stateSinkObject.Save(12, &e.connected) - stateSinkObject.Save(13, &e.bound) - stateSinkObject.Save(14, &e.linger) - stateSinkObject.Save(15, &e.owner) - stateSinkObject.Save(16, &e.ops) + stateSinkObject.Save(4, &e.rcvList) + stateSinkObject.Save(5, &e.rcvBufSize) + stateSinkObject.Save(7, &e.rcvClosed) + stateSinkObject.Save(8, &e.sndBufSize) + stateSinkObject.Save(9, &e.sndBufSizeMax) + stateSinkObject.Save(10, &e.closed) + stateSinkObject.Save(11, &e.connected) + stateSinkObject.Save(12, &e.bound) + stateSinkObject.Save(13, &e.linger) + stateSinkObject.Save(14, &e.owner) + stateSinkObject.Save(15, &e.ops) } func (e *endpoint) StateLoad(stateSourceObject state.Source) { @@ -93,19 +91,18 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) stateSourceObject.Load(2, &e.waiterQueue) stateSourceObject.Load(3, &e.associated) - stateSourceObject.Load(4, &e.hdrIncluded) - stateSourceObject.Load(5, &e.rcvList) - stateSourceObject.Load(6, &e.rcvBufSize) - stateSourceObject.Load(8, &e.rcvClosed) - stateSourceObject.Load(9, &e.sndBufSize) - stateSourceObject.Load(10, &e.sndBufSizeMax) - stateSourceObject.Load(11, &e.closed) - stateSourceObject.Load(12, &e.connected) - stateSourceObject.Load(13, &e.bound) - stateSourceObject.Load(14, &e.linger) - stateSourceObject.Load(15, &e.owner) - stateSourceObject.Load(16, &e.ops) - stateSourceObject.LoadValue(7, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) + stateSourceObject.Load(4, &e.rcvList) + stateSourceObject.Load(5, &e.rcvBufSize) + stateSourceObject.Load(7, &e.rcvClosed) + stateSourceObject.Load(8, &e.sndBufSize) + stateSourceObject.Load(9, &e.sndBufSizeMax) + stateSourceObject.Load(10, &e.closed) + stateSourceObject.Load(11, &e.connected) + stateSourceObject.Load(12, &e.bound) + stateSourceObject.Load(13, &e.linger) + stateSourceObject.Load(14, &e.owner) + stateSourceObject.Load(15, &e.ops) + stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) stateSourceObject.AfterLoad(e.afterLoad) } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 5f2221f1b..3e1041cbe 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -213,7 +213,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i route.ResolveWith(s.remoteLinkAddr) n := newEndpoint(l.stack, netProto, queue) - n.v6only = l.v6Only + n.ops.SetV6Only(l.v6Only) n.ID = s.id n.boundNICID = s.nicID n.route = route @@ -752,7 +752,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er // its own goroutine and is responsible for handling connection requests. func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() - v6Only := e.v6only + v6Only := e.ops.GetV6Only() ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) defer func() { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index e38488d4d..31eded0ce 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1078,7 +1078,7 @@ func (e *endpoint) transitionToStateCloseLocked() { // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) - if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { + if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) } @@ -1635,7 +1635,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() } extTW, newSyn := e.rcv.handleTimeWaitSegment(s) if newSyn { - info := e.EndpointInfo.TransportEndpointInfo + info := e.TransportEndpointInfo newID := info.ID newID.RemoteAddress = "" newID.RemotePort = 0 diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 713a70b47..fb64851ae 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -309,18 +309,6 @@ type Stats struct { // marker interface. func (*Stats) IsEndpointStats() {} -// EndpointInfo holds useful information about a transport endpoint which -// can be queried by monitoring tools. -// -// +stateify savable -type EndpointInfo struct { - stack.TransportEndpointInfo -} - -// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo -// marker interface. -func (*EndpointInfo) IsEndpointInfo() {} - // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -361,7 +349,7 @@ func (*EndpointInfo) IsEndpointInfo() {} // // +stateify savable type endpoint struct { - EndpointInfo + stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // endpointEntry is used to queue endpoints for processing to the @@ -442,7 +430,6 @@ type endpoint struct { boundNICID tcpip.NICID route *stack.Route `state:"manual"` ttl uint8 - v6only bool isConnectNotified bool // h stores a reference to the current handshake state if the endpoint is in @@ -865,11 +852,9 @@ type keepalive struct { func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ stack: s, - EndpointInfo: EndpointInfo{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: header.TCPProtocolNumber, - }, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, }, waiterQueue: waiterQueue, state: StateInitial, @@ -888,6 +873,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } e.ops.InitHandler(e) + e.ops.SetMulticastLoop(true) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -1686,21 +1672,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { o = 0 } atomic.StoreUint32(&e.slowAck, o) - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - // We only allow this to be set when we're in the initial state. - if e.EndpointState() != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.LockUser() - e.v6only = v - e.UnlockUser() } return nil @@ -1985,13 +1956,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } -// IsListening implements tcpip.SocketOptionsHandler.IsListening. -func (e *endpoint) IsListening() bool { - e.LockUser() - defer e.UnlockUser() - return e.EndpointState() == StateListen -} - // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { @@ -2006,21 +1970,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { v := atomic.LoadUint32(&e.slowAck) == 0 return v, nil - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrUnknownProtocolOption - } - - e.LockUser() - v := e.v6only - e.UnlockUser() - - return v, nil - - case tcpip.MulticastLoopOption: - return true, nil - default: return false, tcpip.ErrUnknownProtocolOption } @@ -2182,7 +2131,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -2716,7 +2665,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // v6only set to false. if netProto == header.IPv6ProtocolNumber { stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) - alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4 + alsoBindToV4 := !e.ops.GetV6Only() && addr.Addr == "" && stackHasV4 if alsoBindToV4 { netProtos = append(netProtos, header.IPv4ProtocolNumber) } @@ -3180,7 +3129,7 @@ func (e *endpoint) State() uint32 { func (e *endpoint) Info() tcpip.EndpointInfo { e.LockUser() // Make a copy of the endpoint info. - ret := e.EndpointInfo + ret := e.TransportEndpointInfo e.UnlockUser() return &ret } diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go index 4f369151b..590602fdf 100644 --- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -127,36 +127,13 @@ func (r *rcvBufAutoTuneParams) StateLoad(stateSourceObject state.Source) { stateSourceObject.LoadValue(5, new(unixTime), func(y interface{}) { r.loadRttMeasureTime(y.(unixTime)) }) } -func (e *EndpointInfo) StateTypeName() string { - return "pkg/tcpip/transport/tcp.EndpointInfo" -} - -func (e *EndpointInfo) StateFields() []string { - return []string{ - "TransportEndpointInfo", - } -} - -func (e *EndpointInfo) beforeSave() {} - -func (e *EndpointInfo) StateSave(stateSinkObject state.Sink) { - e.beforeSave() - stateSinkObject.Save(0, &e.TransportEndpointInfo) -} - -func (e *EndpointInfo) afterLoad() {} - -func (e *EndpointInfo) StateLoad(stateSourceObject state.Source) { - stateSourceObject.Load(0, &e.TransportEndpointInfo) -} - func (e *endpoint) StateTypeName() string { return "pkg/tcpip/transport/tcp.endpoint" } func (e *endpoint) StateFields() []string { return []string{ - "EndpointInfo", + "TransportEndpointInfo", "DefaultSocketOptionsHandler", "waiterQueue", "uniqueID", @@ -172,7 +149,6 @@ func (e *endpoint) StateFields() []string { "state", "boundNICID", "ttl", - "v6only", "isConnectNotified", "portFlags", "boundBindToDevice", @@ -234,10 +210,10 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { var stateValue EndpointState = e.saveState() stateSinkObject.SaveValue(13, stateValue) var recentTSTimeValue unixTime = e.saveRecentTSTime() - stateSinkObject.SaveValue(27, recentTSTimeValue) + stateSinkObject.SaveValue(26, recentTSTimeValue) var acceptedChanValue []*endpoint = e.saveAcceptedChan() - stateSinkObject.SaveValue(53, acceptedChanValue) - stateSinkObject.Save(0, &e.EndpointInfo) + stateSinkObject.SaveValue(52, acceptedChanValue) + stateSinkObject.Save(0, &e.TransportEndpointInfo) stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) stateSinkObject.Save(2, &e.waiterQueue) stateSinkObject.Save(3, &e.uniqueID) @@ -250,58 +226,57 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(12, &e.ownedByUser) stateSinkObject.Save(14, &e.boundNICID) stateSinkObject.Save(15, &e.ttl) - stateSinkObject.Save(16, &e.v6only) - stateSinkObject.Save(17, &e.isConnectNotified) - stateSinkObject.Save(18, &e.portFlags) - stateSinkObject.Save(19, &e.boundBindToDevice) - stateSinkObject.Save(20, &e.boundPortFlags) - stateSinkObject.Save(21, &e.boundDest) - stateSinkObject.Save(22, &e.effectiveNetProtos) - stateSinkObject.Save(23, &e.workerRunning) - stateSinkObject.Save(24, &e.workerCleanup) - stateSinkObject.Save(25, &e.sendTSOk) - stateSinkObject.Save(26, &e.recentTS) - stateSinkObject.Save(28, &e.tsOffset) - stateSinkObject.Save(29, &e.shutdownFlags) - stateSinkObject.Save(30, &e.sackPermitted) - stateSinkObject.Save(31, &e.sack) - stateSinkObject.Save(32, &e.bindToDevice) - stateSinkObject.Save(33, &e.delay) - stateSinkObject.Save(34, &e.cork) - stateSinkObject.Save(35, &e.scoreboard) - stateSinkObject.Save(36, &e.slowAck) - stateSinkObject.Save(37, &e.segmentQueue) - stateSinkObject.Save(38, &e.synRcvdCount) - stateSinkObject.Save(39, &e.userMSS) - stateSinkObject.Save(40, &e.maxSynRetries) - stateSinkObject.Save(41, &e.windowClamp) - stateSinkObject.Save(42, &e.sndBufSize) - stateSinkObject.Save(43, &e.sndBufUsed) - stateSinkObject.Save(44, &e.sndClosed) - stateSinkObject.Save(45, &e.sndBufInQueue) - stateSinkObject.Save(46, &e.sndQueue) - stateSinkObject.Save(47, &e.cc) - stateSinkObject.Save(48, &e.packetTooBigCount) - stateSinkObject.Save(49, &e.sndMTU) - stateSinkObject.Save(50, &e.keepalive) - stateSinkObject.Save(51, &e.userTimeout) - stateSinkObject.Save(52, &e.deferAccept) - stateSinkObject.Save(54, &e.rcv) - stateSinkObject.Save(55, &e.snd) - stateSinkObject.Save(56, &e.connectingAddress) - stateSinkObject.Save(57, &e.amss) - stateSinkObject.Save(58, &e.sendTOS) - stateSinkObject.Save(59, &e.gso) - stateSinkObject.Save(60, &e.tcpLingerTimeout) - stateSinkObject.Save(61, &e.closed) - stateSinkObject.Save(62, &e.txHash) - stateSinkObject.Save(63, &e.owner) - stateSinkObject.Save(64, &e.linger) - stateSinkObject.Save(65, &e.ops) + stateSinkObject.Save(16, &e.isConnectNotified) + stateSinkObject.Save(17, &e.portFlags) + stateSinkObject.Save(18, &e.boundBindToDevice) + stateSinkObject.Save(19, &e.boundPortFlags) + stateSinkObject.Save(20, &e.boundDest) + stateSinkObject.Save(21, &e.effectiveNetProtos) + stateSinkObject.Save(22, &e.workerRunning) + stateSinkObject.Save(23, &e.workerCleanup) + stateSinkObject.Save(24, &e.sendTSOk) + stateSinkObject.Save(25, &e.recentTS) + stateSinkObject.Save(27, &e.tsOffset) + stateSinkObject.Save(28, &e.shutdownFlags) + stateSinkObject.Save(29, &e.sackPermitted) + stateSinkObject.Save(30, &e.sack) + stateSinkObject.Save(31, &e.bindToDevice) + stateSinkObject.Save(32, &e.delay) + stateSinkObject.Save(33, &e.cork) + stateSinkObject.Save(34, &e.scoreboard) + stateSinkObject.Save(35, &e.slowAck) + stateSinkObject.Save(36, &e.segmentQueue) + stateSinkObject.Save(37, &e.synRcvdCount) + stateSinkObject.Save(38, &e.userMSS) + stateSinkObject.Save(39, &e.maxSynRetries) + stateSinkObject.Save(40, &e.windowClamp) + stateSinkObject.Save(41, &e.sndBufSize) + stateSinkObject.Save(42, &e.sndBufUsed) + stateSinkObject.Save(43, &e.sndClosed) + stateSinkObject.Save(44, &e.sndBufInQueue) + stateSinkObject.Save(45, &e.sndQueue) + stateSinkObject.Save(46, &e.cc) + stateSinkObject.Save(47, &e.packetTooBigCount) + stateSinkObject.Save(48, &e.sndMTU) + stateSinkObject.Save(49, &e.keepalive) + stateSinkObject.Save(50, &e.userTimeout) + stateSinkObject.Save(51, &e.deferAccept) + stateSinkObject.Save(53, &e.rcv) + stateSinkObject.Save(54, &e.snd) + stateSinkObject.Save(55, &e.connectingAddress) + stateSinkObject.Save(56, &e.amss) + stateSinkObject.Save(57, &e.sendTOS) + stateSinkObject.Save(58, &e.gso) + stateSinkObject.Save(59, &e.tcpLingerTimeout) + stateSinkObject.Save(60, &e.closed) + stateSinkObject.Save(61, &e.txHash) + stateSinkObject.Save(62, &e.owner) + stateSinkObject.Save(63, &e.linger) + stateSinkObject.Save(64, &e.ops) } func (e *endpoint) StateLoad(stateSourceObject state.Source) { - stateSourceObject.Load(0, &e.EndpointInfo) + stateSourceObject.Load(0, &e.TransportEndpointInfo) stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) stateSourceObject.LoadWait(2, &e.waiterQueue) stateSourceObject.Load(3, &e.uniqueID) @@ -314,59 +289,58 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(12, &e.ownedByUser) stateSourceObject.Load(14, &e.boundNICID) stateSourceObject.Load(15, &e.ttl) - stateSourceObject.Load(16, &e.v6only) - stateSourceObject.Load(17, &e.isConnectNotified) - stateSourceObject.Load(18, &e.portFlags) - stateSourceObject.Load(19, &e.boundBindToDevice) - stateSourceObject.Load(20, &e.boundPortFlags) - stateSourceObject.Load(21, &e.boundDest) - stateSourceObject.Load(22, &e.effectiveNetProtos) - stateSourceObject.Load(23, &e.workerRunning) - stateSourceObject.Load(24, &e.workerCleanup) - stateSourceObject.Load(25, &e.sendTSOk) - stateSourceObject.Load(26, &e.recentTS) - stateSourceObject.Load(28, &e.tsOffset) - stateSourceObject.Load(29, &e.shutdownFlags) - stateSourceObject.Load(30, &e.sackPermitted) - stateSourceObject.Load(31, &e.sack) - stateSourceObject.Load(32, &e.bindToDevice) - stateSourceObject.Load(33, &e.delay) - stateSourceObject.Load(34, &e.cork) - stateSourceObject.Load(35, &e.scoreboard) - stateSourceObject.Load(36, &e.slowAck) - stateSourceObject.LoadWait(37, &e.segmentQueue) - stateSourceObject.Load(38, &e.synRcvdCount) - stateSourceObject.Load(39, &e.userMSS) - stateSourceObject.Load(40, &e.maxSynRetries) - stateSourceObject.Load(41, &e.windowClamp) - stateSourceObject.Load(42, &e.sndBufSize) - stateSourceObject.Load(43, &e.sndBufUsed) - stateSourceObject.Load(44, &e.sndClosed) - stateSourceObject.Load(45, &e.sndBufInQueue) - stateSourceObject.LoadWait(46, &e.sndQueue) - stateSourceObject.Load(47, &e.cc) - stateSourceObject.Load(48, &e.packetTooBigCount) - stateSourceObject.Load(49, &e.sndMTU) - stateSourceObject.Load(50, &e.keepalive) - stateSourceObject.Load(51, &e.userTimeout) - stateSourceObject.Load(52, &e.deferAccept) - stateSourceObject.LoadWait(54, &e.rcv) - stateSourceObject.LoadWait(55, &e.snd) - stateSourceObject.Load(56, &e.connectingAddress) - stateSourceObject.Load(57, &e.amss) - stateSourceObject.Load(58, &e.sendTOS) - stateSourceObject.Load(59, &e.gso) - stateSourceObject.Load(60, &e.tcpLingerTimeout) - stateSourceObject.Load(61, &e.closed) - stateSourceObject.Load(62, &e.txHash) - stateSourceObject.Load(63, &e.owner) - stateSourceObject.Load(64, &e.linger) - stateSourceObject.Load(65, &e.ops) + stateSourceObject.Load(16, &e.isConnectNotified) + stateSourceObject.Load(17, &e.portFlags) + stateSourceObject.Load(18, &e.boundBindToDevice) + stateSourceObject.Load(19, &e.boundPortFlags) + stateSourceObject.Load(20, &e.boundDest) + stateSourceObject.Load(21, &e.effectiveNetProtos) + stateSourceObject.Load(22, &e.workerRunning) + stateSourceObject.Load(23, &e.workerCleanup) + stateSourceObject.Load(24, &e.sendTSOk) + stateSourceObject.Load(25, &e.recentTS) + stateSourceObject.Load(27, &e.tsOffset) + stateSourceObject.Load(28, &e.shutdownFlags) + stateSourceObject.Load(29, &e.sackPermitted) + stateSourceObject.Load(30, &e.sack) + stateSourceObject.Load(31, &e.bindToDevice) + stateSourceObject.Load(32, &e.delay) + stateSourceObject.Load(33, &e.cork) + stateSourceObject.Load(34, &e.scoreboard) + stateSourceObject.Load(35, &e.slowAck) + stateSourceObject.LoadWait(36, &e.segmentQueue) + stateSourceObject.Load(37, &e.synRcvdCount) + stateSourceObject.Load(38, &e.userMSS) + stateSourceObject.Load(39, &e.maxSynRetries) + stateSourceObject.Load(40, &e.windowClamp) + stateSourceObject.Load(41, &e.sndBufSize) + stateSourceObject.Load(42, &e.sndBufUsed) + stateSourceObject.Load(43, &e.sndClosed) + stateSourceObject.Load(44, &e.sndBufInQueue) + stateSourceObject.LoadWait(45, &e.sndQueue) + stateSourceObject.Load(46, &e.cc) + stateSourceObject.Load(47, &e.packetTooBigCount) + stateSourceObject.Load(48, &e.sndMTU) + stateSourceObject.Load(49, &e.keepalive) + stateSourceObject.Load(50, &e.userTimeout) + stateSourceObject.Load(51, &e.deferAccept) + stateSourceObject.LoadWait(53, &e.rcv) + stateSourceObject.LoadWait(54, &e.snd) + stateSourceObject.Load(55, &e.connectingAddress) + stateSourceObject.Load(56, &e.amss) + stateSourceObject.Load(57, &e.sendTOS) + stateSourceObject.Load(58, &e.gso) + stateSourceObject.Load(59, &e.tcpLingerTimeout) + stateSourceObject.Load(60, &e.closed) + stateSourceObject.Load(61, &e.txHash) + stateSourceObject.Load(62, &e.owner) + stateSourceObject.Load(63, &e.linger) + stateSourceObject.Load(64, &e.ops) stateSourceObject.LoadValue(4, new(string), func(y interface{}) { e.loadHardError(y.(string)) }) stateSourceObject.LoadValue(5, new(string), func(y interface{}) { e.loadLastError(y.(string)) }) stateSourceObject.LoadValue(13, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) }) - stateSourceObject.LoadValue(27, new(unixTime), func(y interface{}) { e.loadRecentTSTime(y.(unixTime)) }) - stateSourceObject.LoadValue(53, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) }) + stateSourceObject.LoadValue(26, new(unixTime), func(y interface{}) { e.loadRecentTSTime(y.(unixTime)) }) + stateSourceObject.LoadValue(52, new([]*endpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*endpoint)) }) stateSourceObject.AfterLoad(e.afterLoad) } @@ -1031,7 +1005,6 @@ func init() { state.Register((*cubicState)(nil)) state.Register((*SACKInfo)(nil)) state.Register((*rcvBufAutoTuneParams)(nil)) - state.Register((*EndpointInfo)(nil)) state.Register((*endpoint)(nil)) state.Register((*keepalive)(nil)) state.Register((*rackControl)(nil)) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9d33a694b..a9c74148b 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -101,12 +101,10 @@ type endpoint struct { state EndpointState route *stack.Route `state:"manual"` dstPort uint16 - v6only bool ttl uint8 multicastTTL uint8 multicastAddr tcpip.Address multicastNICID tcpip.NICID - multicastLoop bool portFlags ports.Flags bindToDevice tcpip.NICID @@ -122,17 +120,6 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 - // receiveTOS determines if the incoming IPv4 TOS header field is passed - // as ancillary data to ControlMessages on Read. - receiveTOS bool - - // receiveTClass determines if the incoming IPv6 TClass header field is - // passed as ancillary data to ControlMessages on Read. - receiveTClass bool - - // receiveIPPacketInfo determines if the packet info is returned by Read. - receiveIPPacketInfo bool - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -188,7 +175,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Linux defaults to TTL=1. multicastTTL: 1, - multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), @@ -196,6 +182,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue uniqueID: s.UniqueID(), } e.ops.InitHandler(e) + e.ops.SetMulticastLoop(true) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -307,21 +294,16 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess HasTimestamp: true, Timestamp: p.timestamp, } - e.mu.RLock() - receiveTOS := e.receiveTOS - receiveTClass := e.receiveTClass - receiveIPPacketInfo := e.receiveIPPacketInfo - e.mu.RUnlock() - if receiveTOS { + if e.ops.GetReceiveTOS() { cm.HasTOS = true cm.TOS = p.tos } - if receiveTClass { + if e.ops.GetReceiveTClass() { cm.HasTClass = true // Although TClass is an 8-bit value it's read in the CMsg as a uint32. cm.TClass = uint32(p.tos) } - if receiveIPPacketInfo { + if e.ops.GetReceivePacketInfo() { cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo } @@ -388,7 +370,7 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) if err != nil { return nil, 0, err } @@ -595,48 +577,6 @@ func (e *endpoint) OnReusePortSet(v bool) { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.MulticastLoopOption: - e.mu.Lock() - e.multicastLoop = v - e.mu.Unlock() - - case tcpip.ReceiveTOSOption: - e.mu.Lock() - e.receiveTOS = v - e.mu.Unlock() - - case tcpip.ReceiveTClassOption: - // We only support this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrNotSupported - } - - e.mu.Lock() - e.receiveTClass = v - e.mu.Unlock() - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.Lock() - e.receiveIPPacketInfo = v - e.mu.Unlock() - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.v6only = v - } return nil } @@ -851,51 +791,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.MulticastLoopOption: - e.mu.RLock() - v := e.multicastLoop - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveTOSOption: - e.mu.RLock() - v := e.receiveTOS - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveTClassOption: - // We only support this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrNotSupported - } - - e.mu.RLock() - v := e.receiveTClass - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.RLock() - v := e.receiveIPPacketInfo - e.mu.RUnlock() - return v, nil - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrUnknownProtocolOption - } - - e.mu.RLock() - v := e.v6only - e.mu.RUnlock() - - return v, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } + return false, tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -1036,7 +932,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -1147,7 +1043,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // packets on a different network protocol, so we register both even if // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only { + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv4ProtocolNumber, header.IPv6ProtocolNumber, @@ -1259,7 +1155,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // wildcard (empty) address, and this is an IPv6 endpoint with v6only // set to false. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv6ProtocolNumber, header.IPv4ProtocolNumber, diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 99f3fc37f..9d06035ea 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -114,7 +114,7 @@ func (e *endpoint) Resume(s *stack.Stack) { var err *tcpip.Error if e.state == StateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go index cb05b5cd8..dc6afdff9 100644 --- a/pkg/tcpip/transport/udp/udp_state_autogen.go +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -65,21 +65,16 @@ func (e *endpoint) StateFields() []string { "sndBufSizeMax", "state", "dstPort", - "v6only", "ttl", "multicastTTL", "multicastAddr", "multicastNICID", - "multicastLoop", "portFlags", "bindToDevice", "lastError", "boundBindToDevice", "boundPortFlags", "sendTOS", - "receiveTOS", - "receiveTClass", - "receiveIPPacketInfo", "shutdownFlags", "multicastMemberships", "effectiveNetProtos", @@ -94,7 +89,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax() stateSinkObject.SaveValue(6, rcvBufSizeMaxValue) var lastErrorValue string = e.saveLastError() - stateSinkObject.SaveValue(21, lastErrorValue) + stateSinkObject.SaveValue(19, lastErrorValue) stateSinkObject.Save(0, &e.TransportEndpointInfo) stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) stateSinkObject.Save(2, &e.waiterQueue) @@ -107,26 +102,21 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(10, &e.sndBufSizeMax) stateSinkObject.Save(11, &e.state) stateSinkObject.Save(12, &e.dstPort) - stateSinkObject.Save(13, &e.v6only) - stateSinkObject.Save(14, &e.ttl) - stateSinkObject.Save(15, &e.multicastTTL) - stateSinkObject.Save(16, &e.multicastAddr) - stateSinkObject.Save(17, &e.multicastNICID) - stateSinkObject.Save(18, &e.multicastLoop) - stateSinkObject.Save(19, &e.portFlags) - stateSinkObject.Save(20, &e.bindToDevice) - stateSinkObject.Save(22, &e.boundBindToDevice) - stateSinkObject.Save(23, &e.boundPortFlags) - stateSinkObject.Save(24, &e.sendTOS) - stateSinkObject.Save(25, &e.receiveTOS) - stateSinkObject.Save(26, &e.receiveTClass) - stateSinkObject.Save(27, &e.receiveIPPacketInfo) - stateSinkObject.Save(28, &e.shutdownFlags) - stateSinkObject.Save(29, &e.multicastMemberships) - stateSinkObject.Save(30, &e.effectiveNetProtos) - stateSinkObject.Save(31, &e.owner) - stateSinkObject.Save(32, &e.linger) - stateSinkObject.Save(33, &e.ops) + stateSinkObject.Save(13, &e.ttl) + stateSinkObject.Save(14, &e.multicastTTL) + stateSinkObject.Save(15, &e.multicastAddr) + stateSinkObject.Save(16, &e.multicastNICID) + stateSinkObject.Save(17, &e.portFlags) + stateSinkObject.Save(18, &e.bindToDevice) + stateSinkObject.Save(20, &e.boundBindToDevice) + stateSinkObject.Save(21, &e.boundPortFlags) + stateSinkObject.Save(22, &e.sendTOS) + stateSinkObject.Save(23, &e.shutdownFlags) + stateSinkObject.Save(24, &e.multicastMemberships) + stateSinkObject.Save(25, &e.effectiveNetProtos) + stateSinkObject.Save(26, &e.owner) + stateSinkObject.Save(27, &e.linger) + stateSinkObject.Save(28, &e.ops) } func (e *endpoint) StateLoad(stateSourceObject state.Source) { @@ -142,28 +132,23 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(10, &e.sndBufSizeMax) stateSourceObject.Load(11, &e.state) stateSourceObject.Load(12, &e.dstPort) - stateSourceObject.Load(13, &e.v6only) - stateSourceObject.Load(14, &e.ttl) - stateSourceObject.Load(15, &e.multicastTTL) - stateSourceObject.Load(16, &e.multicastAddr) - stateSourceObject.Load(17, &e.multicastNICID) - stateSourceObject.Load(18, &e.multicastLoop) - stateSourceObject.Load(19, &e.portFlags) - stateSourceObject.Load(20, &e.bindToDevice) - stateSourceObject.Load(22, &e.boundBindToDevice) - stateSourceObject.Load(23, &e.boundPortFlags) - stateSourceObject.Load(24, &e.sendTOS) - stateSourceObject.Load(25, &e.receiveTOS) - stateSourceObject.Load(26, &e.receiveTClass) - stateSourceObject.Load(27, &e.receiveIPPacketInfo) - stateSourceObject.Load(28, &e.shutdownFlags) - stateSourceObject.Load(29, &e.multicastMemberships) - stateSourceObject.Load(30, &e.effectiveNetProtos) - stateSourceObject.Load(31, &e.owner) - stateSourceObject.Load(32, &e.linger) - stateSourceObject.Load(33, &e.ops) + stateSourceObject.Load(13, &e.ttl) + stateSourceObject.Load(14, &e.multicastTTL) + stateSourceObject.Load(15, &e.multicastAddr) + stateSourceObject.Load(16, &e.multicastNICID) + stateSourceObject.Load(17, &e.portFlags) + stateSourceObject.Load(18, &e.bindToDevice) + stateSourceObject.Load(20, &e.boundBindToDevice) + stateSourceObject.Load(21, &e.boundPortFlags) + stateSourceObject.Load(22, &e.sendTOS) + stateSourceObject.Load(23, &e.shutdownFlags) + stateSourceObject.Load(24, &e.multicastMemberships) + stateSourceObject.Load(25, &e.effectiveNetProtos) + stateSourceObject.Load(26, &e.owner) + stateSourceObject.Load(27, &e.linger) + stateSourceObject.Load(28, &e.ops) stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) - stateSourceObject.LoadValue(21, new(string), func(y interface{}) { e.loadLastError(y.(string)) }) + stateSourceObject.LoadValue(19, new(string), func(y interface{}) { e.loadLastError(y.(string)) }) stateSourceObject.AfterLoad(e.afterLoad) } |