summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/netstack/netstack.go11
-rw-r--r--pkg/tcpip/socketops.go14
-rw-r--r--pkg/tcpip/tcpip_state_autogen.go3
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go14
-rw-r--r--pkg/tcpip/transport/udp/udp_state_autogen.go55
5 files changed, 48 insertions, 49 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9c927efa0..d48b92c66 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1232,12 +1232,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.NoChecksumOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum()))
+ return &v, nil
case linux.SO_ACCEPTCONN:
if outLen < sizeOfInt32 {
@@ -1977,7 +1973,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0))
+ ep.SocketOptions().SetNoChecksum(v != 0)
+ return nil
case linux.SO_LINGER:
if len(optVal) < linux.SizeOfLinger {
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index e1b0d6354..cc3d59d9d 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -32,6 +32,10 @@ type SocketOptions struct {
// passCredEnabled determines whether SCM_CREDENTIALS socket control messages
// are enabled.
passCredEnabled uint32
+
+ // noChecksumEnabled determines whether UDP checksum is disabled while
+ // transmitting for this socket.
+ noChecksumEnabled uint32
}
func storeAtomicBool(addr *uint32, v bool) {
@@ -61,3 +65,13 @@ func (so *SocketOptions) GetPassCred() bool {
func (so *SocketOptions) SetPassCred(v bool) {
storeAtomicBool(&so.passCredEnabled, v)
}
+
+// GetNoChecksum gets value for SO_NO_CHECK option.
+func (so *SocketOptions) GetNoChecksum() bool {
+ return atomic.LoadUint32(&so.noChecksumEnabled) != 0
+}
+
+// SetNoChecksum sets value for SO_NO_CHECK option.
+func (so *SocketOptions) SetNoChecksum(v bool) {
+ storeAtomicBool(&so.noChecksumEnabled, v)
+}
diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go
index 66cad60db..a50f49a2f 100644
--- a/pkg/tcpip/tcpip_state_autogen.go
+++ b/pkg/tcpip/tcpip_state_autogen.go
@@ -14,6 +14,7 @@ func (so *SocketOptions) StateFields() []string {
return []string{
"broadcastEnabled",
"passCredEnabled",
+ "noChecksumEnabled",
}
}
@@ -23,6 +24,7 @@ func (so *SocketOptions) StateSave(stateSinkObject state.Sink) {
so.beforeSave()
stateSinkObject.Save(0, &so.broadcastEnabled)
stateSinkObject.Save(1, &so.passCredEnabled)
+ stateSinkObject.Save(2, &so.noChecksumEnabled)
}
func (so *SocketOptions) afterLoad() {}
@@ -30,6 +32,7 @@ func (so *SocketOptions) afterLoad() {}
func (so *SocketOptions) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(0, &so.broadcastEnabled)
stateSourceObject.Load(1, &so.passCredEnabled)
+ stateSourceObject.Load(2, &so.noChecksumEnabled)
}
func (f *FullAddress) StateTypeName() string {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 648587137..5aa16bf35 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -108,7 +108,6 @@ type endpoint struct {
multicastLoop bool
portFlags ports.Flags
bindToDevice tcpip.NICID
- noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -550,7 +549,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
localPort := e.ID.LocalPort
sendTOS := e.sendTOS
owner := e.owner
- noChecksum := e.noChecksum
+ noChecksum := e.SocketOptions().GetNoChecksum()
lockReleased = true
e.mu.RUnlock()
@@ -583,11 +582,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.multicastLoop = v
e.mu.Unlock()
- case tcpip.NoChecksumOption:
- e.mu.Lock()
- e.noChecksum = v
- e.mu.Unlock()
-
case tcpip.ReceiveTOSOption:
e.mu.Lock()
e.receiveTOS = v
@@ -858,12 +852,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
- case tcpip.NoChecksumOption:
- e.mu.RLock()
- v := e.noChecksum
- e.mu.RUnlock()
- return v, nil
-
case tcpip.ReceiveTOSOption:
e.mu.RLock()
v := e.receiveTOS
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
index 6a715cc10..9350a4809 100644
--- a/pkg/tcpip/transport/udp/udp_state_autogen.go
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -72,7 +72,6 @@ func (e *endpoint) StateFields() []string {
"multicastLoop",
"portFlags",
"bindToDevice",
- "noChecksum",
"lastError",
"boundBindToDevice",
"boundPortFlags",
@@ -94,7 +93,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax()
stateSinkObject.SaveValue(5, rcvBufSizeMaxValue)
var lastErrorValue string = e.saveLastError()
- stateSinkObject.SaveValue(21, lastErrorValue)
+ stateSinkObject.SaveValue(20, lastErrorValue)
stateSinkObject.Save(0, &e.TransportEndpointInfo)
stateSinkObject.Save(1, &e.waiterQueue)
stateSinkObject.Save(2, &e.uniqueID)
@@ -114,19 +113,18 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(17, &e.multicastLoop)
stateSinkObject.Save(18, &e.portFlags)
stateSinkObject.Save(19, &e.bindToDevice)
- stateSinkObject.Save(20, &e.noChecksum)
- stateSinkObject.Save(22, &e.boundBindToDevice)
- stateSinkObject.Save(23, &e.boundPortFlags)
- stateSinkObject.Save(24, &e.sendTOS)
- stateSinkObject.Save(25, &e.receiveTOS)
- stateSinkObject.Save(26, &e.receiveTClass)
- stateSinkObject.Save(27, &e.receiveIPPacketInfo)
- stateSinkObject.Save(28, &e.shutdownFlags)
- stateSinkObject.Save(29, &e.multicastMemberships)
- stateSinkObject.Save(30, &e.effectiveNetProtos)
- stateSinkObject.Save(31, &e.owner)
- stateSinkObject.Save(32, &e.linger)
- stateSinkObject.Save(33, &e.ops)
+ stateSinkObject.Save(21, &e.boundBindToDevice)
+ stateSinkObject.Save(22, &e.boundPortFlags)
+ stateSinkObject.Save(23, &e.sendTOS)
+ stateSinkObject.Save(24, &e.receiveTOS)
+ stateSinkObject.Save(25, &e.receiveTClass)
+ stateSinkObject.Save(26, &e.receiveIPPacketInfo)
+ stateSinkObject.Save(27, &e.shutdownFlags)
+ stateSinkObject.Save(28, &e.multicastMemberships)
+ stateSinkObject.Save(29, &e.effectiveNetProtos)
+ stateSinkObject.Save(30, &e.owner)
+ stateSinkObject.Save(31, &e.linger)
+ stateSinkObject.Save(32, &e.ops)
}
func (e *endpoint) StateLoad(stateSourceObject state.Source) {
@@ -149,21 +147,20 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(17, &e.multicastLoop)
stateSourceObject.Load(18, &e.portFlags)
stateSourceObject.Load(19, &e.bindToDevice)
- stateSourceObject.Load(20, &e.noChecksum)
- stateSourceObject.Load(22, &e.boundBindToDevice)
- stateSourceObject.Load(23, &e.boundPortFlags)
- stateSourceObject.Load(24, &e.sendTOS)
- stateSourceObject.Load(25, &e.receiveTOS)
- stateSourceObject.Load(26, &e.receiveTClass)
- stateSourceObject.Load(27, &e.receiveIPPacketInfo)
- stateSourceObject.Load(28, &e.shutdownFlags)
- stateSourceObject.Load(29, &e.multicastMemberships)
- stateSourceObject.Load(30, &e.effectiveNetProtos)
- stateSourceObject.Load(31, &e.owner)
- stateSourceObject.Load(32, &e.linger)
- stateSourceObject.Load(33, &e.ops)
+ stateSourceObject.Load(21, &e.boundBindToDevice)
+ stateSourceObject.Load(22, &e.boundPortFlags)
+ stateSourceObject.Load(23, &e.sendTOS)
+ stateSourceObject.Load(24, &e.receiveTOS)
+ stateSourceObject.Load(25, &e.receiveTClass)
+ stateSourceObject.Load(26, &e.receiveIPPacketInfo)
+ stateSourceObject.Load(27, &e.shutdownFlags)
+ stateSourceObject.Load(28, &e.multicastMemberships)
+ stateSourceObject.Load(29, &e.effectiveNetProtos)
+ stateSourceObject.Load(30, &e.owner)
+ stateSourceObject.Load(31, &e.linger)
+ stateSourceObject.Load(32, &e.ops)
stateSourceObject.LoadValue(5, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) })
- stateSourceObject.LoadValue(21, new(string), func(y interface{}) { e.loadLastError(y.(string)) })
+ stateSourceObject.LoadValue(20, new(string), func(y interface{}) { e.loadLastError(y.(string)) })
stateSourceObject.AfterLoad(e.afterLoad)
}