diff options
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/socketops.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/tcpip_state_autogen.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_state_autogen.go | 55 |
5 files changed, 48 insertions, 49 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9c927efa0..d48b92c66 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1232,12 +1232,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum())) + return &v, nil case linux.SO_ACCEPTCONN: if outLen < sizeOfInt32 { @@ -1977,7 +1973,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + ep.SocketOptions().SetNoChecksum(v != 0) + return nil case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index e1b0d6354..cc3d59d9d 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -32,6 +32,10 @@ type SocketOptions struct { // passCredEnabled determines whether SCM_CREDENTIALS socket control messages // are enabled. passCredEnabled uint32 + + // noChecksumEnabled determines whether UDP checksum is disabled while + // transmitting for this socket. + noChecksumEnabled uint32 } func storeAtomicBool(addr *uint32, v bool) { @@ -61,3 +65,13 @@ func (so *SocketOptions) GetPassCred() bool { func (so *SocketOptions) SetPassCred(v bool) { storeAtomicBool(&so.passCredEnabled, v) } + +// GetNoChecksum gets value for SO_NO_CHECK option. +func (so *SocketOptions) GetNoChecksum() bool { + return atomic.LoadUint32(&so.noChecksumEnabled) != 0 +} + +// SetNoChecksum sets value for SO_NO_CHECK option. +func (so *SocketOptions) SetNoChecksum(v bool) { + storeAtomicBool(&so.noChecksumEnabled, v) +} diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go index 66cad60db..a50f49a2f 100644 --- a/pkg/tcpip/tcpip_state_autogen.go +++ b/pkg/tcpip/tcpip_state_autogen.go @@ -14,6 +14,7 @@ func (so *SocketOptions) StateFields() []string { return []string{ "broadcastEnabled", "passCredEnabled", + "noChecksumEnabled", } } @@ -23,6 +24,7 @@ func (so *SocketOptions) StateSave(stateSinkObject state.Sink) { so.beforeSave() stateSinkObject.Save(0, &so.broadcastEnabled) stateSinkObject.Save(1, &so.passCredEnabled) + stateSinkObject.Save(2, &so.noChecksumEnabled) } func (so *SocketOptions) afterLoad() {} @@ -30,6 +32,7 @@ func (so *SocketOptions) afterLoad() {} func (so *SocketOptions) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(0, &so.broadcastEnabled) stateSourceObject.Load(1, &so.passCredEnabled) + stateSourceObject.Load(2, &so.noChecksumEnabled) } func (f *FullAddress) StateTypeName() string { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 648587137..5aa16bf35 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -108,7 +108,6 @@ type endpoint struct { multicastLoop bool portFlags ports.Flags bindToDevice tcpip.NICID - noChecksum bool lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -550,7 +549,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c localPort := e.ID.LocalPort sendTOS := e.sendTOS owner := e.owner - noChecksum := e.noChecksum + noChecksum := e.SocketOptions().GetNoChecksum() lockReleased = true e.mu.RUnlock() @@ -583,11 +582,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.multicastLoop = v e.mu.Unlock() - case tcpip.NoChecksumOption: - e.mu.Lock() - e.noChecksum = v - e.mu.Unlock() - case tcpip.ReceiveTOSOption: e.mu.Lock() e.receiveTOS = v @@ -858,12 +852,6 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil - case tcpip.NoChecksumOption: - e.mu.RLock() - v := e.noChecksum - e.mu.RUnlock() - return v, nil - case tcpip.ReceiveTOSOption: e.mu.RLock() v := e.receiveTOS diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go index 6a715cc10..9350a4809 100644 --- a/pkg/tcpip/transport/udp/udp_state_autogen.go +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -72,7 +72,6 @@ func (e *endpoint) StateFields() []string { "multicastLoop", "portFlags", "bindToDevice", - "noChecksum", "lastError", "boundBindToDevice", "boundPortFlags", @@ -94,7 +93,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax() stateSinkObject.SaveValue(5, rcvBufSizeMaxValue) var lastErrorValue string = e.saveLastError() - stateSinkObject.SaveValue(21, lastErrorValue) + stateSinkObject.SaveValue(20, lastErrorValue) stateSinkObject.Save(0, &e.TransportEndpointInfo) stateSinkObject.Save(1, &e.waiterQueue) stateSinkObject.Save(2, &e.uniqueID) @@ -114,19 +113,18 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(17, &e.multicastLoop) stateSinkObject.Save(18, &e.portFlags) stateSinkObject.Save(19, &e.bindToDevice) - stateSinkObject.Save(20, &e.noChecksum) - stateSinkObject.Save(22, &e.boundBindToDevice) - stateSinkObject.Save(23, &e.boundPortFlags) - stateSinkObject.Save(24, &e.sendTOS) - stateSinkObject.Save(25, &e.receiveTOS) - stateSinkObject.Save(26, &e.receiveTClass) - stateSinkObject.Save(27, &e.receiveIPPacketInfo) - stateSinkObject.Save(28, &e.shutdownFlags) - stateSinkObject.Save(29, &e.multicastMemberships) - stateSinkObject.Save(30, &e.effectiveNetProtos) - stateSinkObject.Save(31, &e.owner) - stateSinkObject.Save(32, &e.linger) - stateSinkObject.Save(33, &e.ops) + stateSinkObject.Save(21, &e.boundBindToDevice) + stateSinkObject.Save(22, &e.boundPortFlags) + stateSinkObject.Save(23, &e.sendTOS) + stateSinkObject.Save(24, &e.receiveTOS) + stateSinkObject.Save(25, &e.receiveTClass) + stateSinkObject.Save(26, &e.receiveIPPacketInfo) + stateSinkObject.Save(27, &e.shutdownFlags) + stateSinkObject.Save(28, &e.multicastMemberships) + stateSinkObject.Save(29, &e.effectiveNetProtos) + stateSinkObject.Save(30, &e.owner) + stateSinkObject.Save(31, &e.linger) + stateSinkObject.Save(32, &e.ops) } func (e *endpoint) StateLoad(stateSourceObject state.Source) { @@ -149,21 +147,20 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(17, &e.multicastLoop) stateSourceObject.Load(18, &e.portFlags) stateSourceObject.Load(19, &e.bindToDevice) - stateSourceObject.Load(20, &e.noChecksum) - stateSourceObject.Load(22, &e.boundBindToDevice) - stateSourceObject.Load(23, &e.boundPortFlags) - stateSourceObject.Load(24, &e.sendTOS) - stateSourceObject.Load(25, &e.receiveTOS) - stateSourceObject.Load(26, &e.receiveTClass) - stateSourceObject.Load(27, &e.receiveIPPacketInfo) - stateSourceObject.Load(28, &e.shutdownFlags) - stateSourceObject.Load(29, &e.multicastMemberships) - stateSourceObject.Load(30, &e.effectiveNetProtos) - stateSourceObject.Load(31, &e.owner) - stateSourceObject.Load(32, &e.linger) - stateSourceObject.Load(33, &e.ops) + stateSourceObject.Load(21, &e.boundBindToDevice) + stateSourceObject.Load(22, &e.boundPortFlags) + stateSourceObject.Load(23, &e.sendTOS) + stateSourceObject.Load(24, &e.receiveTOS) + stateSourceObject.Load(25, &e.receiveTClass) + stateSourceObject.Load(26, &e.receiveIPPacketInfo) + stateSourceObject.Load(27, &e.shutdownFlags) + stateSourceObject.Load(28, &e.multicastMemberships) + stateSourceObject.Load(29, &e.effectiveNetProtos) + stateSourceObject.Load(30, &e.owner) + stateSourceObject.Load(31, &e.linger) + stateSourceObject.Load(32, &e.ops) stateSourceObject.LoadValue(5, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) }) - stateSourceObject.LoadValue(21, new(string), func(y interface{}) { e.loadLastError(y.(string)) }) + stateSourceObject.LoadValue(20, new(string), func(y interface{}) { e.loadLastError(y.(string)) }) stateSourceObject.AfterLoad(e.afterLoad) } |