summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/netstack/netstack.go149
-rw-r--r--pkg/tcpip/socketops.go106
-rw-r--r--pkg/tcpip/stack/ndp_test.go12
-rw-r--r--pkg/tcpip/stack/transport_test.go1
-rw-r--r--pkg/tcpip/tcpip.go29
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go1
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go25
-rw-r--r--pkg/tcpip/transport/tcp/accept.go4
-rw-r--r--pkg/tcpip/transport/tcp/connect.go4
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go18
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go67
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go4
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go122
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go55
17 files changed, 250 insertions, 359 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index e8a0103bf..1184acc7a 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -260,6 +260,10 @@ type commonEndpoint interface {
// transport.Endpoint.GetSockOpt.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+ // State returns a socket's lifecycle state. The returned value is
+ // protocol-specific and is primarily used for diagnostics.
+ State() uint32
+
// LastError implements tcpip.Endpoint.LastError and
// transport.Endpoint.LastError.
LastError() *tcpip.Error
@@ -723,11 +727,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
return nil
}
if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
- v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return syserr.TranslateNetstackError(err)
- }
- if !v {
+ if !s.Endpoint.SocketOptions().GetV6Only() {
return nil
}
}
@@ -1226,8 +1226,13 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v := primitive.Int32(boolToInt32(ep.SocketOptions().GetAcceptConn()))
- return &v, nil
+ // This option is only viable for TCP endpoints.
+ var v bool
+ if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) {
+ v = tcp.EndpointState(ep.State()) == tcp.StateListen
+ }
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
@@ -1449,19 +1454,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, _ := s.Type()
+ if family != linux.AF_INET6 {
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only()))
+ return &v, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1493,13 +1503,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
+ return &v, nil
case linux.IP6T_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet6{})) {
@@ -1520,7 +1525,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1540,7 +1545,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1560,7 +1565,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1582,6 +1587,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
// getSockOptIP implements GetSockOpt when level is SOL_IP.
func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1633,13 +1643,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop()))
+ return &v, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
@@ -1663,26 +1668,24 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS()))
+ return &v, nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo()))
+ return &v, nil
+
+ case linux.IP_HDRINCL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded()))
+ return &v, nil
case linux.SO_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet{})) {
@@ -2127,14 +2130,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, skProto := s.Type()
+ if family != linux.AF_INET6 {
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
+ if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ }
+
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
+ ep.SocketOptions().SetV6Only(v != 0)
+ return nil
case linux.IPV6_ADD_MEMBERSHIP,
linux.IPV6_DROP_MEMBERSHIP,
@@ -2173,7 +2193,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+ ep.SocketOptions().SetReceiveTClass(v != 0)
+ return nil
case linux.IP6T_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIP6TReplace {
@@ -2181,7 +2202,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return syserr.ErrProtocolNotAvailable
}
@@ -2256,6 +2277,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
// setSockOptIP implements SetSockOpt when level is SOL_IP.
func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_MULTICAST_TTL:
v, err := parseIntOrChar(optVal)
@@ -2317,7 +2343,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
+ ep.SocketOptions().SetMulticastLoop(v != 0)
+ return nil
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -2353,7 +2380,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ ep.SocketOptions().SetReceiveTOS(v != 0)
+ return nil
case linux.IP_PKTINFO:
if len(optVal) == 0 {
@@ -2363,7 +2391,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+ ep.SocketOptions().SetReceivePacketInfo(v != 0)
+ return nil
case linux.IP_HDRINCL:
if len(optVal) == 0 {
@@ -2373,7 +2402,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+ ep.SocketOptions().SetHeaderIncluded(v != 0)
+ return nil
case linux.IPT_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIPTReplace {
@@ -2515,7 +2545,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
switch name {
case linux.IP_TOS,
linux.IP_TTL,
- linux.IP_HDRINCL,
linux.IP_OPTIONS,
linux.IP_ROUTER_ALERT,
linux.IP_RECVOPTS,
@@ -3384,6 +3413,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
return rv
}
+func isTCPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP)
+}
+
+func isUDPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP)
+}
+
+func isICMPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6)
+}
+
// State implements socket.Socket.State. State translates the internal state
// returned by netstack to values defined by Linux.
func (s *socketOpsCommon) State() uint32 {
@@ -3393,7 +3434,7 @@ func (s *socketOpsCommon) State() uint32 {
}
switch {
- case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP:
+ case isTCPSocket(s.skType, s.protocol):
// TCP socket.
switch tcp.EndpointState(s.Endpoint.State()) {
case tcp.StateEstablished:
@@ -3422,7 +3463,7 @@ func (s *socketOpsCommon) State() uint32 {
// Internal or unknown state.
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP:
+ case isUDPSocket(s.skType, s.protocol):
// UDP socket.
switch udp.EndpointState(s.Endpoint.State()) {
case udp.StateInitial, udp.StateBound, udp.StateClosed:
@@ -3432,7 +3473,7 @@ func (s *socketOpsCommon) State() uint32 {
default:
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6:
+ case isICMPSocket(s.skType, s.protocol):
// TODO(b/112063468): Export states for ICMP sockets.
case s.skType == linux.SOCK_RAW:
// TODO(b/112063468): Export states for raw sockets.
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 1b1188ee5..cced4d8fc 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -19,10 +19,8 @@ import (
)
// SocketOptionsHandler holds methods that help define endpoint specific
-// behavior for socket options.
-// These must be implemented by endpoints to:
-// - Get notified when socket level options are set.
-// - Provide endpoint specific socket options.
+// behavior for socket level socket options. These must be implemented by
+// endpoints to get notified when socket level options are set.
type SocketOptionsHandler interface {
// OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint.
OnReuseAddressSet(v bool)
@@ -32,10 +30,6 @@ type SocketOptionsHandler interface {
// OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint.
OnKeepAliveSet(v bool)
-
- // IsListening is invoked to fetch SO_ACCEPTCONN option value for an
- // endpoint. It is used to indicate if the socket is a listening socket.
- IsListening() bool
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -53,11 +47,8 @@ func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {}
// OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet.
func (*DefaultSocketOptionsHandler) OnKeepAliveSet(bool) {}
-// IsListening implements SocketOptionsHandler.IsListening.
-func (*DefaultSocketOptionsHandler) IsListening() bool { return false }
-
-// SocketOptions contains all the variables which store values for SOL_SOCKET
-// level options.
+// SocketOptions contains all the variables which store values for SOL_SOCKET,
+// SOL_IP and SOL_IPV6 level options.
//
// +stateify savable
type SocketOptions struct {
@@ -88,6 +79,31 @@ type SocketOptions struct {
// keepAliveEnabled determines whether TCP keepalive is enabled for this
// socket.
keepAliveEnabled uint32
+
+ // multicastLoopEnabled determines whether multicast packets sent over a
+ // non-loopback interface will be looped back. Analogous to inet->mc_loop.
+ multicastLoopEnabled uint32
+
+ // receiveTOSEnabled is used to specify if the TOS ancillary message is
+ // passed with incoming packets.
+ receiveTOSEnabled uint32
+
+ // receiveTClassEnabled is used to specify if the IPV6_TCLASS ancillary
+ // message is passed with incoming packets.
+ receiveTClassEnabled uint32
+
+ // receivePacketInfoEnabled is used to specify if more inforamtion is
+ // provided with incoming packets such as interface index and address.
+ receivePacketInfoEnabled uint32
+
+ // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets
+ // being written have an IP header and the endpoint should not attach an IP
+ // header.
+ hdrIncludedEnabled uint32
+
+ // v6OnlyEnabled is used to determine whether an IPv6 socket is to be
+ // restricted to sending and receiving IPv6 packets only.
+ v6OnlyEnabled uint32
}
// InitHandler initializes the handler. This must be called before using the
@@ -167,8 +183,64 @@ func (so *SocketOptions) SetKeepAlive(v bool) {
so.handler.OnKeepAliveSet(v)
}
-// GetAcceptConn gets value for SO_ACCEPTCONN option.
-func (so *SocketOptions) GetAcceptConn() bool {
- // This option is completely endpoint dependent and unsettable.
- return so.handler.IsListening()
+// GetMulticastLoop gets value for IP_MULTICAST_LOOP option.
+func (so *SocketOptions) GetMulticastLoop() bool {
+ return atomic.LoadUint32(&so.multicastLoopEnabled) != 0
+}
+
+// SetMulticastLoop sets value for IP_MULTICAST_LOOP option.
+func (so *SocketOptions) SetMulticastLoop(v bool) {
+ storeAtomicBool(&so.multicastLoopEnabled, v)
+}
+
+// GetReceiveTOS gets value for IP_RECVTOS option.
+func (so *SocketOptions) GetReceiveTOS() bool {
+ return atomic.LoadUint32(&so.receiveTOSEnabled) != 0
+}
+
+// SetReceiveTOS sets value for IP_RECVTOS option.
+func (so *SocketOptions) SetReceiveTOS(v bool) {
+ storeAtomicBool(&so.receiveTOSEnabled, v)
+}
+
+// GetReceiveTClass gets value for IPV6_RECVTCLASS option.
+func (so *SocketOptions) GetReceiveTClass() bool {
+ return atomic.LoadUint32(&so.receiveTClassEnabled) != 0
+}
+
+// SetReceiveTClass sets value for IPV6_RECVTCLASS option.
+func (so *SocketOptions) SetReceiveTClass(v bool) {
+ storeAtomicBool(&so.receiveTClassEnabled, v)
+}
+
+// GetReceivePacketInfo gets value for IP_PKTINFO option.
+func (so *SocketOptions) GetReceivePacketInfo() bool {
+ return atomic.LoadUint32(&so.receivePacketInfoEnabled) != 0
+}
+
+// SetReceivePacketInfo sets value for IP_PKTINFO option.
+func (so *SocketOptions) SetReceivePacketInfo(v bool) {
+ storeAtomicBool(&so.receivePacketInfoEnabled, v)
+}
+
+// GetHeaderIncluded gets value for IP_HDRINCL option.
+func (so *SocketOptions) GetHeaderIncluded() bool {
+ return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0
+}
+
+// SetHeaderIncluded sets value for IP_HDRINCL option.
+func (so *SocketOptions) SetHeaderIncluded(v bool) {
+ storeAtomicBool(&so.hdrIncludedEnabled, v)
+}
+
+// GetV6Only gets value for IPV6_V6ONLY option.
+func (so *SocketOptions) GetV6Only() bool {
+ return atomic.LoadUint32(&so.v6OnlyEnabled) != 0
+}
+
+// SetV6Only sets value for IPV6_V6ONLY option.
+//
+// Preconditions: the backing TCP or UDP endpoint must be in initial state.
+func (so *SocketOptions) SetV6Only(v bool) {
+ storeAtomicBool(&so.v6OnlyEnabled, v)
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 67e4dfb91..b790b3e97 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -2849,9 +2849,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(addr); err != nil {
t.Fatalf("ep.Connect(%+v): %s", addr, err)
}
@@ -2885,9 +2883,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Bind(addr); err != nil {
t.Fatalf("ep.Bind(%+v): %s", addr, err)
}
@@ -3256,9 +3252,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 3c6ec0c3a..081c21fa9 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -65,6 +65,7 @@ func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
return &f.ops
}
+
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
ep.ops.InitHandler(ep)
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 6ed00e74f..2eb6e76af 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -718,37 +718,8 @@ const (
// TCP, it determines if the Nagle algorithm is on or off.
DelayOption
- // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether multicast packets sent over a non-loopback interface
- // will be looped back.
- MulticastLoopOption
-
// QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
QuickAckOption
-
- // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to
- // specify if the IPV6_TCLASS ancillary message is passed with incoming
- // packets.
- ReceiveTClassOption
-
- // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify
- // if the TOS ancillary message is passed with incoming packets.
- ReceiveTOSOption
-
- // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to
- // specify if more inforamtion is provided with incoming packets such as
- // interface index and address.
- ReceiveIPPacketInfoOption
-
- // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether an IPv6 socket is to be restricted to sending and receiving
- // IPv6 packets only.
- V6OnlyOption
-
- // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw
- // endpoint that all packets being written have an IP header and the
- // endpoint should not attach an IP header.
- IPHdrIncludedOption
)
// SockOptInt represents socket options which values have the int type.
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 0a714498d..5eacd8d24 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -148,6 +148,7 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index e2c7a0d62..da402bad9 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -549,8 +549,10 @@ func (ep *endpoint) Stats() tcpip.EndpointStats {
return &ep.stats
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 2b1022995..0478900c3 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -65,7 +65,6 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
- hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
@@ -116,9 +115,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
associated: associated,
- hdrIncluded: !associated,
}
e.ops.InitHandler(e)
+ e.ops.SetHeaderIncluded(!associated)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -271,7 +270,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if e.hdrIncluded {
+ if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -361,7 +360,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- if e.hdrIncluded {
+ if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),
})
@@ -538,13 +537,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.IPHdrIncludedOption:
- e.mu.Lock()
- e.hdrIncluded = v
- e.mu.Unlock()
- return nil
- }
return tcpip.ErrUnknownProtocolOption
}
@@ -608,16 +600,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.IPHdrIncludedOption:
- e.mu.Lock()
- v := e.hdrIncluded
- e.mu.Unlock()
- return v, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
+ return false, tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 5f2221f1b..3e1041cbe 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -213,7 +213,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
route.ResolveWith(s.remoteLinkAddr)
n := newEndpoint(l.stack, netProto, queue)
- n.v6only = l.v6Only
+ n.ops.SetV6Only(l.v6Only)
n.ID = s.id
n.boundNICID = s.nicID
n.route = route
@@ -752,7 +752,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
- v6Only := e.v6only
+ v6Only := e.ops.GetV6Only()
ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
defer func() {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index e38488d4d..31eded0ce 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1078,7 +1078,7 @@ func (e *endpoint) transitionToStateCloseLocked() {
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
- if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
+ if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
}
@@ -1635,7 +1635,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
}
extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
if newSyn {
- info := e.EndpointInfo.TransportEndpointInfo
+ info := e.TransportEndpointInfo
newID := info.ID
newID.RemoteAddress = ""
newID.RemotePort = 0
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index a6f25896b..1d1b01a6c 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -405,14 +405,6 @@ func testV4Accept(t *testing.T, c *context.Context) {
}
}
- // Make sure we get the same error when calling the original ep and the
- // new one. This validates that v4-mapped endpoints are still able to
- // query the V6Only flag, whereas pure v4 endpoints are not.
- _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
- t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
- }
-
// Check the peer address.
addr, err := nep.GetRemoteAddress()
if err != nil {
@@ -530,12 +522,12 @@ func TestV6AcceptOnV6(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
var addr tcpip.FullAddress
- nep, _, err := c.EP.Accept(&addr)
+ _, _, err := c.EP.Accept(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept(&addr)
+ _, _, err = c.EP.Accept(&addr)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -548,12 +540,6 @@ func TestV6AcceptOnV6(t *testing.T) {
if addr.Addr != context.TestV6Addr {
t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
}
-
- // Make sure we can still query the v6 only status of the new endpoint,
- // that is, that it is in fact a v6 socket.
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
- t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err)
- }
}
func TestV4AcceptOnV4(t *testing.T) {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 713a70b47..fb64851ae 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -309,18 +309,6 @@ type Stats struct {
// marker interface.
func (*Stats) IsEndpointStats() {}
-// EndpointInfo holds useful information about a transport endpoint which
-// can be queried by monitoring tools.
-//
-// +stateify savable
-type EndpointInfo struct {
- stack.TransportEndpointInfo
-}
-
-// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
-// marker interface.
-func (*EndpointInfo) IsEndpointInfo() {}
-
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -361,7 +349,7 @@ func (*EndpointInfo) IsEndpointInfo() {}
//
// +stateify savable
type endpoint struct {
- EndpointInfo
+ stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// endpointEntry is used to queue endpoints for processing to the
@@ -442,7 +430,6 @@ type endpoint struct {
boundNICID tcpip.NICID
route *stack.Route `state:"manual"`
ttl uint8
- v6only bool
isConnectNotified bool
// h stores a reference to the current handshake state if the endpoint is in
@@ -865,11 +852,9 @@ type keepalive struct {
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
stack: s,
- EndpointInfo: EndpointInfo{
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: header.TCPProtocolNumber,
- },
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: header.TCPProtocolNumber,
},
waiterQueue: waiterQueue,
state: StateInitial,
@@ -888,6 +873,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
maxSynRetries: DefaultSynRetries,
}
e.ops.InitHandler(e)
+ e.ops.SetMulticastLoop(true)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -1686,21 +1672,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
o = 0
}
atomic.StoreUint32(&e.slowAck, o)
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- // We only allow this to be set when we're in the initial state.
- if e.EndpointState() != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.LockUser()
- e.v6only = v
- e.UnlockUser()
}
return nil
@@ -1985,13 +1956,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
return e.rcvBufUsed, nil
}
-// IsListening implements tcpip.SocketOptionsHandler.IsListening.
-func (e *endpoint) IsListening() bool {
- e.LockUser()
- defer e.UnlockUser()
- return e.EndpointState() == StateListen
-}
-
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
@@ -2006,21 +1970,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
v := atomic.LoadUint32(&e.slowAck) == 0
return v, nil
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrUnknownProtocolOption
- }
-
- e.LockUser()
- v := e.v6only
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.MulticastLoopOption:
- return true, nil
-
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -2182,7 +2131,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -2716,7 +2665,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// v6only set to false.
if netProto == header.IPv6ProtocolNumber {
stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber)
- alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4
+ alsoBindToV4 := !e.ops.GetV6Only() && addr.Addr == "" && stackHasV4
if alsoBindToV4 {
netProtos = append(netProtos, header.IPv4ProtocolNumber)
}
@@ -3180,7 +3129,7 @@ func (e *endpoint) State() uint32 {
func (e *endpoint) Info() tcpip.EndpointInfo {
e.LockUser()
// Make a copy of the endpoint info.
- ret := e.EndpointInfo
+ ret := e.TransportEndpointInfo
e.UnlockUser()
return &ret
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 7124a715d..dfe2b4c6c 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -4642,13 +4642,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
switch network {
case "ipv4":
case "ipv6":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
case "dual":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
- t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err)
- }
+ ep.SocketOptions().SetV6Only(false)
default:
t.Fatalf("unknown network: '%s'", network)
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index e6aa4fc4b..010a23e45 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -592,9 +592,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
+ c.EP.SocketOptions().SetV6Only(v6only)
}
// GetV6Packet reads a single packet from the link layer endpoint of the context
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 9d33a694b..a9c74148b 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -101,12 +101,10 @@ type endpoint struct {
state EndpointState
route *stack.Route `state:"manual"`
dstPort uint16
- v6only bool
ttl uint8
multicastTTL uint8
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
- multicastLoop bool
portFlags ports.Flags
bindToDevice tcpip.NICID
@@ -122,17 +120,6 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
- // receiveTOS determines if the incoming IPv4 TOS header field is passed
- // as ancillary data to ControlMessages on Read.
- receiveTOS bool
-
- // receiveTClass determines if the incoming IPv6 TClass header field is
- // passed as ancillary data to ControlMessages on Read.
- receiveTClass bool
-
- // receiveIPPacketInfo determines if the packet info is returned by Read.
- receiveIPPacketInfo bool
-
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -188,7 +175,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
//
// Linux defaults to TTL=1.
multicastTTL: 1,
- multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
multicastMemberships: make(map[multicastMembership]struct{}),
@@ -196,6 +182,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
uniqueID: s.UniqueID(),
}
e.ops.InitHandler(e)
+ e.ops.SetMulticastLoop(true)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -307,21 +294,16 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
HasTimestamp: true,
Timestamp: p.timestamp,
}
- e.mu.RLock()
- receiveTOS := e.receiveTOS
- receiveTClass := e.receiveTClass
- receiveIPPacketInfo := e.receiveIPPacketInfo
- e.mu.RUnlock()
- if receiveTOS {
+ if e.ops.GetReceiveTOS() {
cm.HasTOS = true
cm.TOS = p.tos
}
- if receiveTClass {
+ if e.ops.GetReceiveTClass() {
cm.HasTClass = true
// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
cm.TClass = uint32(p.tos)
}
- if receiveIPPacketInfo {
+ if e.ops.GetReceivePacketInfo() {
cm.HasIPPacketInfo = true
cm.PacketInfo = p.packetInfo
}
@@ -388,7 +370,7 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
if err != nil {
return nil, 0, err
}
@@ -595,48 +577,6 @@ func (e *endpoint) OnReusePortSet(v bool) {
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.MulticastLoopOption:
- e.mu.Lock()
- e.multicastLoop = v
- e.mu.Unlock()
-
- case tcpip.ReceiveTOSOption:
- e.mu.Lock()
- e.receiveTOS = v
- e.mu.Unlock()
-
- case tcpip.ReceiveTClassOption:
- // We only support this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrNotSupported
- }
-
- e.mu.Lock()
- e.receiveTClass = v
- e.mu.Unlock()
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.Lock()
- e.receiveIPPacketInfo = v
- e.mu.Unlock()
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.v6only = v
- }
return nil
}
@@ -851,51 +791,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.MulticastLoopOption:
- e.mu.RLock()
- v := e.multicastLoop
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveTOSOption:
- e.mu.RLock()
- v := e.receiveTOS
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveTClassOption:
- // We only support this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrNotSupported
- }
-
- e.mu.RLock()
- v := e.receiveTClass
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.RLock()
- v := e.receiveIPPacketInfo
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.RLock()
- v := e.v6only
- e.mu.RUnlock()
-
- return v, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
+ return false, tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -1036,7 +932,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -1147,7 +1043,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only {
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv4ProtocolNumber,
header.IPv6ProtocolNumber,
@@ -1259,7 +1155,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// wildcard (empty) address, and this is an IPv6 endpoint with v6only
// set to false.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv6ProtocolNumber,
header.IPv4ProtocolNumber,
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 99f3fc37f..9d06035ea 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -114,7 +114,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
var err *tcpip.Error
if e.state == StateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop)
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 1233bab14..e384f52dd 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -363,9 +363,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
- if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- c.t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetV6Only(true)
} else if flow.isBroadcast() {
c.ep.SocketOptions().SetBroadcast(true)
}
@@ -1414,9 +1412,7 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
- if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil {
- t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err)
- }
+ c.ep.SocketOptions().SetReceivePacketInfo(true)
testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
NIC: 1,
@@ -1629,13 +1625,15 @@ func TestSetTClass(t *testing.T) {
}
func TestReceiveTosTClass(t *testing.T) {
+ const RcvTOSOpt = "ReceiveTosOption"
+ const RcvTClassOpt = "ReceiveTClassOption"
+
testCases := []struct {
- name string
- getReceiveOption tcpip.SockOptBool
- tests []testFlow
+ name string
+ tests []testFlow
}{
- {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
- {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ {RcvTOSOpt, []testFlow{unicastV4, broadcast}},
+ {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
}
for _, testCase := range testCases {
for _, flow := range testCase.tests {
@@ -1644,29 +1642,32 @@ func TestReceiveTosTClass(t *testing.T) {
defer c.cleanup()
c.createEndpointForFlow(flow)
- option := testCase.getReceiveOption
name := testCase.name
- // Verify that setting and reading the option works.
- v, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+ var optionGetter func() bool
+ var optionSetter func(bool)
+ switch name {
+ case RcvTOSOpt:
+ optionGetter = c.ep.SocketOptions().GetReceiveTOS
+ optionSetter = c.ep.SocketOptions().SetReceiveTOS
+ case RcvTClassOpt:
+ optionGetter = c.ep.SocketOptions().GetReceiveTClass
+ optionSetter = c.ep.SocketOptions().SetReceiveTClass
+ default:
+ t.Fatalf("unkown test variant: %s", name)
}
+
+ // Verify that setting and reading the option works.
+ v := optionGetter()
// Test for expected default value.
if v != false {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
}
want := true
- if err := c.ep.SetSockOptBool(option, want); err != nil {
- c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
- }
-
- got, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
- }
+ optionSetter(want)
+ got := optionGetter()
if got != want {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
}
@@ -1676,10 +1677,10 @@ func TestReceiveTosTClass(t *testing.T) {
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %s", err)
}
- switch option {
- case tcpip.ReceiveTClassOption:
+ switch name {
+ case RcvTClassOpt:
testRead(c, flow, checker.ReceiveTClass(testTOS))
- case tcpip.ReceiveTOSOption:
+ case RcvTOSOpt:
testRead(c, flow, checker.ReceiveTOS(testTOS))
default:
t.Fatalf("unknown test variant: %s", name)