diff options
author | Tamir Duberstein <tamird@google.com> | 2021-05-26 18:13:05 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-05-26 18:15:43 -0700 |
commit | 097efe81a19a6ee11738957a3091e99a2caa46d4 (patch) | |
tree | d37d778e7379f9a463ec29232cc2ff737bee4284 /pkg/tcpip/transport | |
parent | 522ae2dd1f3c0d5aea52a9883cc1319e3b1ebce4 (diff) |
Use the stack RNG everywhere
...except in tests.
Note this replaces some uses of a cryptographic RNG with a plain RNG.
PiperOrigin-RevId: 376070666
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/dispatcher.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 2 |
7 files changed, 17 insertions, 34 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index d92b123c4..39f526023 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -616,7 +616,7 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro } // We need to find a port for the endpoint. - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) switch err.(type) { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index db75228a1..2c65b737d 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -591,7 +591,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset()), + TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())), TSEcr: opts.TSVal, MSS: calculateAdvertisedMSS(e.userMSS, route), } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 08849d0cc..570e5081c 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -19,7 +19,6 @@ import ( "math" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -147,11 +146,6 @@ func FindWndScale(wnd seqnum.Size) int { // resetState resets the state of the handshake object such that it becomes // ready for a new 3-way handshake. func (h *handshake) resetState() { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - h.state = handshakeSynSent h.flags = header.TCPFlagSyn h.ackNum = 0 diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 5963a169c..dff7cb89c 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -16,8 +16,8 @@ package tcp import ( "encoding/binary" + "math/rand" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -142,15 +142,16 @@ func (p *processor) start(wg *sync.WaitGroup) { // in-order. type dispatcher struct { processors []processor - seed uint32 - wg sync.WaitGroup + // seed is a random secret for a jenkins hash. + seed uint32 + wg sync.WaitGroup } -func (d *dispatcher) init(nProcessors int) { +func (d *dispatcher) init(rng *rand.Rand, nProcessors int) { d.close() d.wait() d.processors = make([]processor, nProcessors) - d.seed = generateRandUint32() + d.seed = rng.Uint32() for i := range d.processors { p := &d.processors[i] p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) @@ -212,14 +213,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans d.selectProcessor(id).queueEndpoint(ep) } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { var payload [4]byte binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d44f480ab..a27e2110b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -20,12 +20,12 @@ import ( "fmt" "io" "math" + "math/rand" "runtime" "strings" "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -882,7 +882,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.TSOffset = timeStampOffset() + e.TSOffset = timeStampOffset(e.stack.Rand()) e.acceptCond = sync.NewCond(&e.acceptMu) e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker) @@ -2215,7 +2215,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp BindToDevice: bindToDevice, Dest: addr, } - if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil { if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } @@ -2262,7 +2262,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp BindToDevice: bindToDevice, Dest: addr, } - if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil { return false, nil } } @@ -2598,7 +2598,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { BindToDevice: bindToDevice, Dest: tcpip.FullAddress{}, } - port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { + port, err := e.stack.ReservePort(e.stack.Rand(), portRes, func(p uint16) (bool, tcpip.Error) { id := e.TransportEndpointInfo.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2878,11 +2878,7 @@ func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 { // timeStampOffset returns a randomized timestamp offset to be used when sending // timestamp values in a timestamp option for a TCP segment. -func timeStampOffset() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } +func timeStampOffset(rng *rand.Rand) uint32 { // Initialize a random tsOffset that will be added to the recentTS // everytime the timestamp is sent when the Timestamp option is enabled. // @@ -2892,7 +2888,7 @@ func timeStampOffset() uint32 { // NOTE: This is not completely to spec as normally this should be // initialized in a manner analogous to how sequence numbers are // randomized per connection basis. But for now this is sufficient. - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 + return rng.Uint32() } // maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 7c38ad480..2fc282e73 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -481,6 +481,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { // TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection. recovery: 0, } - p.dispatcher.init(runtime.GOMAXPROCS(0)) + p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0)) return &p } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 83b589d09..b964e446b 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1077,7 +1077,7 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id BindToDevice: bindToDevice, Dest: tcpip.FullAddress{}, } - port, err := e.stack.ReservePort(portRes, nil /* testPort */) + port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */) if err != nil { return id, bindToDevice, err } |