diff options
-rw-r--r-- | pkg/sentry/socket/rpcinet/socket.go | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index f641f25df..207123d6f 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -15,6 +15,7 @@ package rpcinet import ( + "sync/atomic" "syscall" "gvisor.googlesource.com/gvisor/pkg/abi/linux" @@ -57,7 +58,7 @@ type socketOperations struct { // shState is the state of the connection with respect to shutdown. Because // we're mixing non-blocking semantics on the other side we have to adapt for // some strange differences between blocking and non-blocking sockets. - shState tcpip.ShutdownFlags + shState int32 } // Verify that we actually implement socket.Socket. @@ -105,26 +106,35 @@ func translateIOSyscallError(err error) error { // setShutdownFlags will set the shutdown flag so we can handle blocking reads // after a read shutdown. func (s *socketOperations) setShutdownFlags(how int) { + var f tcpip.ShutdownFlags switch how { case linux.SHUT_RD: - s.shState |= tcpip.ShutdownRead + f = tcpip.ShutdownRead case linux.SHUT_WR: - s.shState |= tcpip.ShutdownWrite + f = tcpip.ShutdownWrite case linux.SHUT_RDWR: - s.shState |= tcpip.ShutdownWrite | tcpip.ShutdownRead + f = tcpip.ShutdownWrite | tcpip.ShutdownRead + } + + // Atomically update the flags. + for { + old := atomic.LoadInt32(&s.shState) + if atomic.CompareAndSwapInt32(&s.shState, old, old|int32(f)) { + break + } } } func (s *socketOperations) resetShutdownFlags() { - s.shState = 0 + atomic.StoreInt32(&s.shState, 0) } func (s *socketOperations) isShutRdSet() bool { - return s.shState&tcpip.ShutdownRead != 0 + return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownRead) != 0 } func (s *socketOperations) isShutWrSet() bool { - return s.shState&tcpip.ShutdownWrite != 0 + return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownWrite) != 0 } // Release implements fs.FileOperations.Release. |