diff options
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 12 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 18 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 26 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/socketops.go | 42 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 2 | ||||
-rw-r--r-- | test/syscalls/linux/socket_generic.cc | 58 |
11 files changed, 84 insertions, 89 deletions
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 3baad098b..057f4d294 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -120,9 +120,6 @@ type socketOpsCommon struct { // fixed buffer but only consume this many bytes. sendBufferSize uint32 - // passcred indicates if this socket wants SCM credentials. - passcred bool - // filter indicates that this socket has a BPF filter "installed". // // TODO(gvisor.dev/issue/1119): We don't actually support filtering, @@ -201,10 +198,7 @@ func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { // Passcred implements transport.Credentialer.Passcred. func (s *socketOpsCommon) Passcred() bool { - s.mu.Lock() - passcred := s.passcred - s.mu.Unlock() - return passcred + return s.ep.SocketOptions().GetPassCred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. @@ -419,9 +413,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] } passcred := usermem.ByteOrder.Uint32(opt) - s.mu.Lock() - s.passcred = passcred != 0 - s.mu.Unlock() + s.ep.SocketOptions().SetPassCred(passcred != 0) return nil case linux.SO_ATTACH_FILTER: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 5afe77858..9c927efa0 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -260,10 +260,12 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) - // LastError implements tcpip.Endpoint.LastError. + // LastError implements tcpip.Endpoint.LastError and + // transport.Endpoint.LastError. LastError() *tcpip.Error - // SocketOptions implements tcpip.Endpoint.SocketOptions. + // SocketOptions implements tcpip.Endpoint.SocketOptions and + // transport.Endpoint.SocketOptions. SocketOptions() *tcpip.SocketOptions } @@ -1068,13 +1070,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.PasscredOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetPassCred())) + return &v, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1923,7 +1920,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) + ep.SocketOptions().SetPassCred(v != 0) + return nil case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 18a50e9f8..0324dcd93 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,8 +16,6 @@ package transport import ( - "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" @@ -203,10 +201,11 @@ type Endpoint interface { // procfs. State() uint32 - // LastError implements tcpip.Endpoint.LastError. + // LastError clears and returns the last error reported by the endpoint. LastError() *tcpip.Error - // SocketOptions implements tcpip.Endpoint.SocketOptions. + // SocketOptions returns the structure which contains all the socket + // level options. SocketOptions() *tcpip.SocketOptions } @@ -740,10 +739,6 @@ func (e *connectedEndpoint) CloseUnread() { type baseEndpoint struct { *waiter.Queue - // passcred specifies whether SCM_CREDENTIALS socket control messages are - // enabled on this endpoint. Must be accessed atomically. - passcred int32 - // Mutex protects the below fields. sync.Mutex `state:"nosave"` @@ -786,7 +781,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) { // Passcred implements Credentialer.Passcred. func (e *baseEndpoint) Passcred() bool { - return atomic.LoadInt32(&e.passcred) != 0 + return e.SocketOptions().GetPassCred() } // ConnectedPasscred implements Credentialer.ConnectedPasscred. @@ -796,14 +791,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool { return e.connected != nil && e.connected.Passcred() } -func (e *baseEndpoint) setPasscred(pc bool) { - if pc { - atomic.StoreInt32(&e.passcred, 1) - } else { - atomic.StoreInt32(&e.passcred, 0) - } -} - // Connected implements ConnectingEndpoint.Connected. func (e *baseEndpoint) Connected() bool { return e.receiver != nil && e.connected != nil @@ -870,8 +857,6 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { - case tcpip.PasscredOption: - e.setPasscred(v) case tcpip.ReuseAddressOption: default: log.Warningf("Unsupported socket option: %d", opt) @@ -894,9 +879,6 @@ func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil - case tcpip.PasscredOption: - return e.Passcred(), nil - default: log.Warningf("Unsupported socket option: %d", opt) return false, tcpip.ErrUnknownProtocolOption diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 3e520d2ee..b32bb7ba8 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -115,9 +115,6 @@ type socketOpsCommon struct { // bound, they cannot be modified. abstractName string abstractNamespace *kernel.AbstractSocketNamespace - - // ops is used to get socket level options. - ops tcpip.SocketOptions } func (s *socketOpsCommon) isPacket() bool { diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 2a6c7c7c0..e1b0d6354 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -15,31 +15,49 @@ package tcpip import ( - "gvisor.dev/gvisor/pkg/sync" + "sync/atomic" ) -// SocketOptions contains all the variables which store values for socket +// SocketOptions contains all the variables which store values for SOL_SOCKET // level options. // // +stateify savable type SocketOptions struct { - // mu protects fields below. - mu sync.Mutex `state:"nosave"` - broadcastEnabled bool + // These fields are accessed and modified using atomic operations. + + // broadcastEnabled determines whether datagram sockets are allowed to send + // packets to a broadcast address. + broadcastEnabled uint32 + + // passCredEnabled determines whether SCM_CREDENTIALS socket control messages + // are enabled. + passCredEnabled uint32 +} + +func storeAtomicBool(addr *uint32, v bool) { + var val uint32 + if v { + val = 1 + } + atomic.StoreUint32(addr, val) } // GetBroadcast gets value for SO_BROADCAST option. func (so *SocketOptions) GetBroadcast() bool { - so.mu.Lock() - defer so.mu.Unlock() - - return so.broadcastEnabled + return atomic.LoadUint32(&so.broadcastEnabled) != 0 } // SetBroadcast sets value for SO_BROADCAST option. func (so *SocketOptions) SetBroadcast(v bool) { - so.mu.Lock() - defer so.mu.Unlock() + storeAtomicBool(&so.broadcastEnabled, v) +} + +// GetPassCred gets value for SO_PASSCRED option. +func (so *SocketOptions) GetPassCred() bool { + return atomic.LoadUint32(&so.passCredEnabled) != 0 +} - so.broadcastEnabled = v +// SetPassCred sets value for SO_PASSCRED option. +func (so *SocketOptions) SetPassCred(v bool) { + storeAtomicBool(&so.passCredEnabled, v) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f9e83dd1c..09361360f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -721,12 +721,6 @@ const ( // whether UDP checksum is disabled for this socket. NoChecksumOption - // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify - // whether SCM_CREDENTIALS socket control messages are enabled. - // - // Only supported on Unix sockets. - PasscredOption - // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool. QuickAckOption diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 440cb0352..fe6514bcd 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -857,6 +857,7 @@ func (*endpoint) LastError() *tcpip.Error { return nil } +// SocketOptions implements tcpip.Endpoint.SocketOptions. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 4ae1f92ab..0a1e1fbb3 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -756,10 +756,12 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} +// LastError implements tcpip.Endpoint.LastError. func (*endpoint) LastError() *tcpip.Error { return nil } +// SocketOptions implements tcpip.Endpoint.SocketOptions. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 173cd28ec..36b915510 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1279,6 +1279,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } @@ -1299,6 +1300,7 @@ func (e *endpoint) lastErrorLocked() *tcpip.Error { return err } +// LastError implements tcpip.Endpoint.LastError. func (e *endpoint) LastError() *tcpip.Error { e.LockUser() defer e.UnlockUser() @@ -3213,6 +3215,7 @@ func (e *endpoint) Wait() { } } +// SocketOptions implements tcpip.Endpoint.SocketOptions. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 835dcc54e..81601f559 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1535,10 +1535,12 @@ func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } +// SocketOptions implements tcpip.Endpoint.SocketOptions. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index 796546224..c81ba031d 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -818,32 +818,38 @@ TEST_P(AllSocketPairTest, GetSockoptProtocol) { } } -TEST_P(AllSocketPairTest, GetSockoptBroadcast) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - int opt = -1; - socklen_t optlen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, &opt, &optlen), - SyscallSucceeds()); - ASSERT_EQ(optlen, sizeof(opt)); - EXPECT_EQ(opt, 0); -} - -TEST_P(AllSocketPairTest, SetAndGetSocketBroadcastOption) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - int kSockOptOn = 1; - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceedsWithValue(0)); - - int got = -1; - socklen_t length = sizeof(got); - ASSERT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, &got, &length), - SyscallSucceedsWithValue(0)); - - ASSERT_EQ(length, sizeof(got)); - EXPECT_EQ(got, kSockOptOn); +TEST_P(AllSocketPairTest, SetAndGetBooleanSocketOptions) { + int sock_opts[] = {SO_BROADCAST, SO_PASSCRED}; + for (int sock_opt : sock_opts) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int enable = -1; + socklen_t enableLen = sizeof(enable); + + // Test that the option is initially set to false. + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, sock_opt, &enable, + &enableLen), + SyscallSucceeds()); + ASSERT_EQ(enableLen, sizeof(enable)); + EXPECT_EQ(enable, 0) << absl::StrFormat( + "getsockopt(fd, SOL_SOCKET, %d, &enable, &enableLen) => enable=%d", + sock_opt, enable); + + // Test that setting the option to true is reflected in the subsequent + // call to getsockopt(2). + enable = 1; + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, sock_opt, &enable, + sizeof(enable)), + SyscallSucceeds()); + enable = -1; + enableLen = sizeof(enable); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, sock_opt, &enable, + &enableLen), + SyscallSucceeds()); + ASSERT_EQ(enableLen, sizeof(enable)); + EXPECT_EQ(enable, 1) << absl::StrFormat( + "getsockopt(fd, SOL_SOCKET, %d, &enable, &enableLen) => enable=%d", + sock_opt, enable); + } } } // namespace testing |