summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/netstack/netstack.go11
-rw-r--r--pkg/tcpip/socketops.go14
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go14
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go8
-rw-r--r--test/syscalls/linux/socket_generic.cc2
5 files changed, 22 insertions, 27 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9c927efa0..d48b92c66 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1232,12 +1232,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.NoChecksumOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum()))
+ return &v, nil
case linux.SO_ACCEPTCONN:
if outLen < sizeOfInt32 {
@@ -1977,7 +1973,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0))
+ ep.SocketOptions().SetNoChecksum(v != 0)
+ return nil
case linux.SO_LINGER:
if len(optVal) < linux.SizeOfLinger {
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index e1b0d6354..cc3d59d9d 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -32,6 +32,10 @@ type SocketOptions struct {
// passCredEnabled determines whether SCM_CREDENTIALS socket control messages
// are enabled.
passCredEnabled uint32
+
+ // noChecksumEnabled determines whether UDP checksum is disabled while
+ // transmitting for this socket.
+ noChecksumEnabled uint32
}
func storeAtomicBool(addr *uint32, v bool) {
@@ -61,3 +65,13 @@ func (so *SocketOptions) GetPassCred() bool {
func (so *SocketOptions) SetPassCred(v bool) {
storeAtomicBool(&so.passCredEnabled, v)
}
+
+// GetNoChecksum gets value for SO_NO_CHECK option.
+func (so *SocketOptions) GetNoChecksum() bool {
+ return atomic.LoadUint32(&so.noChecksumEnabled) != 0
+}
+
+// SetNoChecksum sets value for SO_NO_CHECK option.
+func (so *SocketOptions) SetNoChecksum(v bool) {
+ storeAtomicBool(&so.noChecksumEnabled, v)
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 648587137..5aa16bf35 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -108,7 +108,6 @@ type endpoint struct {
multicastLoop bool
portFlags ports.Flags
bindToDevice tcpip.NICID
- noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -550,7 +549,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
localPort := e.ID.LocalPort
sendTOS := e.sendTOS
owner := e.owner
- noChecksum := e.noChecksum
+ noChecksum := e.SocketOptions().GetNoChecksum()
lockReleased = true
e.mu.RUnlock()
@@ -583,11 +582,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.multicastLoop = v
e.mu.Unlock()
- case tcpip.NoChecksumOption:
- e.mu.Lock()
- e.noChecksum = v
- e.mu.Unlock()
-
case tcpip.ReceiveTOSOption:
e.mu.Lock()
e.receiveTOS = v
@@ -858,12 +852,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
- case tcpip.NoChecksumOption:
- e.mu.RLock()
- v := e.noChecksum
- e.mu.RUnlock()
- return v, nil
-
case tcpip.ReceiveTOSOption:
e.mu.RLock()
v := e.receiveTOS
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 492e277a8..1233bab14 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1454,16 +1454,12 @@ func TestNoChecksum(t *testing.T) {
c.createEndpointForFlow(flow)
// Disable the checksum generation.
- if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
- t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetNoChecksum(true)
// This option is effective on IPv4 only.
testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
// Enable the checksum generation.
- if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
- t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetNoChecksum(false)
testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
})
}
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index c81ba031d..d17192c36 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -819,7 +819,7 @@ TEST_P(AllSocketPairTest, GetSockoptProtocol) {
}
TEST_P(AllSocketPairTest, SetAndGetBooleanSocketOptions) {
- int sock_opts[] = {SO_BROADCAST, SO_PASSCRED};
+ int sock_opts[] = {SO_BROADCAST, SO_PASSCRED, SO_NO_CHECK};
for (int sock_opt : sock_opts) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
int enable = -1;