diff options
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/socketops.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 8 | ||||
-rw-r--r-- | test/syscalls/linux/socket_generic.cc | 2 |
5 files changed, 22 insertions, 27 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9c927efa0..d48b92c66 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1232,12 +1232,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum())) + return &v, nil case linux.SO_ACCEPTCONN: if outLen < sizeOfInt32 { @@ -1977,7 +1973,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + ep.SocketOptions().SetNoChecksum(v != 0) + return nil case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index e1b0d6354..cc3d59d9d 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -32,6 +32,10 @@ type SocketOptions struct { // passCredEnabled determines whether SCM_CREDENTIALS socket control messages // are enabled. passCredEnabled uint32 + + // noChecksumEnabled determines whether UDP checksum is disabled while + // transmitting for this socket. + noChecksumEnabled uint32 } func storeAtomicBool(addr *uint32, v bool) { @@ -61,3 +65,13 @@ func (so *SocketOptions) GetPassCred() bool { func (so *SocketOptions) SetPassCred(v bool) { storeAtomicBool(&so.passCredEnabled, v) } + +// GetNoChecksum gets value for SO_NO_CHECK option. +func (so *SocketOptions) GetNoChecksum() bool { + return atomic.LoadUint32(&so.noChecksumEnabled) != 0 +} + +// SetNoChecksum sets value for SO_NO_CHECK option. +func (so *SocketOptions) SetNoChecksum(v bool) { + storeAtomicBool(&so.noChecksumEnabled, v) +} diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 648587137..5aa16bf35 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -108,7 +108,6 @@ type endpoint struct { multicastLoop bool portFlags ports.Flags bindToDevice tcpip.NICID - noChecksum bool lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -550,7 +549,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c localPort := e.ID.LocalPort sendTOS := e.sendTOS owner := e.owner - noChecksum := e.noChecksum + noChecksum := e.SocketOptions().GetNoChecksum() lockReleased = true e.mu.RUnlock() @@ -583,11 +582,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.multicastLoop = v e.mu.Unlock() - case tcpip.NoChecksumOption: - e.mu.Lock() - e.noChecksum = v - e.mu.Unlock() - case tcpip.ReceiveTOSOption: e.mu.Lock() e.receiveTOS = v @@ -858,12 +852,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil - case tcpip.NoChecksumOption: - e.mu.RLock() - v := e.noChecksum - e.mu.RUnlock() - return v, nil - case tcpip.ReceiveTOSOption: e.mu.RLock() v := e.receiveTOS diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 492e277a8..1233bab14 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1454,16 +1454,12 @@ func TestNoChecksum(t *testing.T) { c.createEndpointForFlow(flow) // Disable the checksum generation. - if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil { - t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetNoChecksum(true) // This option is effective on IPv4 only. testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) // Enable the checksum generation. - if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil { - t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetNoChecksum(false) testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) }) } diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index c81ba031d..d17192c36 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -819,7 +819,7 @@ TEST_P(AllSocketPairTest, GetSockoptProtocol) { } TEST_P(AllSocketPairTest, SetAndGetBooleanSocketOptions) { - int sock_opts[] = {SO_BROADCAST, SO_PASSCRED}; + int sock_opts[] = {SO_BROADCAST, SO_PASSCRED, SO_NO_CHECK}; for (int sock_opt : sock_opts) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); int enable = -1; |