summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@google.com>2021-05-26 18:13:05 -0700
committergVisor bot <gvisor-bot@google.com>2021-05-26 18:15:43 -0700
commit097efe81a19a6ee11738957a3091e99a2caa46d4 (patch)
treed37d778e7379f9a463ec29232cc2ff737bee4284 /pkg
parent522ae2dd1f3c0d5aea52a9883cc1319e3b1ebce4 (diff)
Use the stack RNG everywhere
...except in tests. Note this replaces some uses of a cryptographic RNG with a plain RNG. PiperOrigin-RevId: 376070666
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go4
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go6
-rw-r--r--pkg/tcpip/ports/ports.go10
-rw-r--r--pkg/tcpip/ports/ports_test.go7
-rw-r--r--pkg/tcpip/stack/iptables.go4
-rw-r--r--pkg/tcpip/stack/ndp_test.go32
-rw-r--r--pkg/tcpip/stack/stack.go55
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go8
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go6
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go18
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
16 files changed, 82 insertions, 99 deletions
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index f42d73178..e1c4b06fc 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -58,8 +58,8 @@ var nameToID = map[string]stack.TableID{
// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for
// compatibility with netfilter extensions.
-func DefaultLinuxTables() *stack.IPTables {
- tables := stack.DefaultTables()
+func DefaultLinuxTables(seed uint32) *stack.IPTables {
+ tables := stack.DefaultTables(seed)
tables.VisitTargets(func(oldTarget stack.Target) stack.Target {
switch val := oldTarget.(type) {
case *stack.AcceptTarget:
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index 71d1c3e28..bc9cf6999 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -16,6 +16,7 @@ package ipv6_test
import (
"bytes"
+ "math/rand"
"testing"
"time"
@@ -138,7 +139,8 @@ func TestSendQueuedMLDReports(t *testing.T) {
var secureRNG bytes.Reader
secureRNG.Reset(secureRNGBytes[:])
s := stack.New(stack.Options{
- SecureRNG: &secureRNG,
+ SecureRNG: &secureRNG,
+ RandSource: rand.NewSource(time.Now().UnixNano()),
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
DupAddrDetectTransmits: test.dadTransmits,
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index b300ed894..0f24500e7 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -17,6 +17,7 @@ package ipv6
import (
"bytes"
"context"
+ "math/rand"
"strings"
"testing"
"time"
@@ -1240,8 +1241,9 @@ func TestCheckDuplicateAddress(t *testing.T) {
var secureRNG bytes.Reader
secureRNG.Reset(secureRNGBytes[:])
s := stack.New(stack.Options{
- SecureRNG: &secureRNG,
- Clock: clock,
+ Clock: clock,
+ RandSource: rand.NewSource(time.Now().UnixNano()),
+ SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
DADConfigs: dadConfigs,
})},
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index b5b013b64..854d6a6ba 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -101,7 +101,7 @@ func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) {
// Wildcard destinations affect all destinations for TupleOnly.
if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress {
// Only bitwise and the TupleOnlyFlag.
- intersection &= ((^TupleOnlyFlag) | counter.SharedFlags())
+ intersection &= (^TupleOnlyFlag) | counter.SharedFlags()
count++
}
}
@@ -238,13 +238,13 @@ type PortTester func(port uint16) (good bool, err tcpip.Error)
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
-func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) {
+func (pm *PortManager) PickEphemeralPort(rng *rand.Rand, testPort PortTester) (port uint16, err tcpip.Error) {
pm.ephemeralMu.RLock()
firstEphemeral := pm.firstEphemeral
numEphemeral := pm.numEphemeral
pm.ephemeralMu.RUnlock()
- offset := uint32(rand.Int31n(int32(numEphemeral)))
+ offset := uint32(rng.Int31n(int32(numEphemeral)))
return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
}
@@ -303,7 +303,7 @@ func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester)
// An optional PortTester can be passed in which if provided will be used to
// test if the picked port can be used. The function should return true if the
// port is safe to use, false otherwise.
-func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
+func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -328,7 +328,7 @@ func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reserv
}
// A port wasn't specified, so try to find one.
- return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) {
res.Port = p
if !pm.reserveSpecificPortLocked(res) {
return false, nil
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 6c4fb8c68..a91b130df 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -18,6 +18,7 @@ import (
"math"
"math/rand"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -331,6 +332,7 @@ func TestPortReservation(t *testing.T) {
t.Run(test.tname, func(t *testing.T) {
pm := NewPortManager()
net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for _, test := range test.actions {
first, _ := pm.PortRange()
@@ -356,7 +358,7 @@ func TestPortReservation(t *testing.T) {
BindToDevice: test.device,
Dest: test.dest,
}
- gotPort, err := pm.ReservePort(portRes, nil /* testPort */)
+ gotPort, err := pm.ReservePort(rng, portRes, nil /* testPort */)
if diff := cmp.Diff(test.want, err); diff != "" {
t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff)
}
@@ -417,10 +419,11 @@ func TestPickEphemeralPort(t *testing.T) {
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
t.Fatalf("failed to set ephemeral port range: %s", err)
}
- port, err := pm.PickEphemeralPort(test.f)
+ port, err := pm.PickEphemeralPort(rng, test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index d2f666c09..0a26f6dd8 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() *IPTables {
+func DefaultTables(seed uint32) *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
NATID: {
@@ -182,7 +182,7 @@ func DefaultTables() *IPTables {
Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
- seed: generateRandUint32(),
+ seed: seed,
},
reaperDone: make(chan struct{}, 1),
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index ac2fa777e..7af5ee953 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -19,11 +19,12 @@ import (
"context"
"encoding/binary"
"fmt"
+ "math/rand"
"testing"
"time"
"github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/rand"
+ cryptorand "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -497,8 +498,9 @@ func TestDADResolve(t *testing.T) {
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
- Clock: clock,
- SecureRNG: &secureRNG,
+ Clock: clock,
+ RandSource: rand.NewSource(time.Now().UnixNano()),
+ SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
DADConfigs: stack.DADConfigurations{
@@ -1035,7 +1037,7 @@ func TestSetNDPConfigurations(t *testing.T) {
// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
// and DHCPv6 configurations specified.
func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
- icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
+ icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length()
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
@@ -1048,13 +1050,13 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
if managedAddress {
// The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
// of the RA payload.
- raPayload[1] |= (1 << 7)
+ raPayload[1] |= 1 << 7
}
// Populate the Other Configurations flag field.
if otherConfigurations {
// The Other Configurations flag field is the 6th bit of byte #1
// (0-indexing) of the RA payload.
- raPayload[1] |= (1 << 6)
+ raPayload[1] |= 1 << 6
}
opts := ra.Options()
opts.Serialize(optSer)
@@ -3833,12 +3835,12 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
const nicName = "nic1"
var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
secretKey := secretKeyBuf[:]
- n, err := rand.Read(secretKey)
+ n, err := cryptorand.Read(secretKey)
if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
+ t.Fatalf("cryptorand.Read(_): %s", err)
}
if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ t.Fatalf("got cryptorand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
}
prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1)
@@ -3946,12 +3948,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
secretKey := secretKeyBuf[:]
- n, err := rand.Read(secretKey)
+ n, err := cryptorand.Read(secretKey)
if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
+ t.Fatalf("cryptorand.Read(_): %s", err)
}
if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ t.Fatalf("got cryptorand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
}
prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
@@ -4311,12 +4313,12 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
secretKey := secretKeyBuf[:]
- n, err := rand.Read(secretKey)
+ n, err := cryptorand.Read(secretKey)
if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
+ t.Fatalf("cryptorand.Read(_): %s", err)
}
if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ t.Fatalf("got cryptorand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
}
prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 1ebf9670c..40d277312 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,7 +20,6 @@
package stack
import (
- "bytes"
"encoding/binary"
"fmt"
"io"
@@ -223,10 +222,16 @@ type Options struct {
// RandSource must be thread-safe.
RandSource rand.Source
- // IPTables are the initial iptables rules. If nil, iptables will allow
+ // IPTables are the initial iptables rules. If nil, DefaultIPTables will be
+ // used to construct the initial iptables rules.
// all traffic.
IPTables *IPTables
+ // DefaultIPTables is an optional iptables rules constructor that is called
+ // if IPTables is nil. If both fields are nil, iptables will allow all
+ // traffic.
+ DefaultIPTables func(uint32) *IPTables
+
// SecureRNG is a cryptographically secure random number generator.
SecureRNG io.Reader
}
@@ -324,23 +329,32 @@ func New(opts Options) *Stack {
opts.UniqueID = new(uniqueIDGenerator)
}
+ if opts.SecureRNG == nil {
+ opts.SecureRNG = cryptorand.Reader
+ }
+
randSrc := opts.RandSource
if randSrc == nil {
+ var v int64
+ if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil {
+ panic(err)
+ }
// Source provided by rand.NewSource is not thread-safe so
// we wrap it in a simple thread-safe version.
- randSrc = &lockedRandomSource{src: rand.NewSource(generateRandInt64())}
+ randSrc = &lockedRandomSource{src: rand.NewSource(v)}
}
+ randomGenerator := rand.New(randSrc)
+ seed := randomGenerator.Uint32()
if opts.IPTables == nil {
- opts.IPTables = DefaultTables()
+ if opts.DefaultIPTables == nil {
+ opts.DefaultIPTables = DefaultTables
+ }
+ opts.IPTables = opts.DefaultIPTables(seed)
}
opts.NUDConfigs.resetInvalidFields()
- if opts.SecureRNG == nil {
- opts.SecureRNG = cryptorand.Reader
- }
-
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
@@ -353,11 +367,11 @@ func New(opts Options) *Stack {
handleLocal: opts.HandleLocal,
tables: opts.IPTables,
icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
+ seed: seed,
nudConfigs: opts.NUDConfigs,
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
- randomGenerator: rand.New(randSrc),
+ randomGenerator: randomGenerator,
secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
@@ -1822,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader {
return s.secureRNG
}
-func generateRandUint32() uint32 {
- b := make([]byte, 4)
- if _, err := cryptorand.Read(b); err != nil {
- panic(err)
- }
- return binary.LittleEndian.Uint32(b)
-}
-
-func generateRandInt64() int64 {
- b := make([]byte, 8)
- if _, err := cryptorand.Read(b); err != nil {
- panic(err)
- }
- buf := bytes.NewReader(b)
- var v int64
- if err := binary.Read(buf, binary.LittleEndian, &v); err != nil {
- panic(err)
- }
- return v
-}
-
// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
s.mu.RLock()
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 80ad1a9d4..8a8454a6a 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -16,8 +16,6 @@ package stack
import (
"fmt"
- "math/rand"
-
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
@@ -223,7 +221,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
return multiPortEp.singleRegisterEndpoint(t, flags)
}
-func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
+func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -475,7 +473,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: rand.Uint32(),
+ seed: d.stack.Seed(),
}
eps.endpoints[id] = epsByNIC
}
@@ -502,7 +500,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum
return nil
}
- return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice)
+ return epsByNIC.checkEndpoint(flags, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d92b123c4..39f526023 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -616,7 +616,7 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro
}
// We need to find a port for the endpoint.
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
switch err.(type) {
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index db75228a1..2c65b737d 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -591,7 +591,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset()),
+ TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())),
TSEcr: opts.TSVal,
MSS: calculateAdvertisedMSS(e.userMSS, route),
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 08849d0cc..570e5081c 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -19,7 +19,6 @@ import (
"math"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -147,11 +146,6 @@ func FindWndScale(wnd seqnum.Size) int {
// resetState resets the state of the handshake object such that it becomes
// ready for a new 3-way handshake.
func (h *handshake) resetState() {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
-
h.state = handshakeSynSent
h.flags = header.TCPFlagSyn
h.ackNum = 0
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 5963a169c..dff7cb89c 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -16,8 +16,8 @@ package tcp
import (
"encoding/binary"
+ "math/rand"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -142,15 +142,16 @@ func (p *processor) start(wg *sync.WaitGroup) {
// in-order.
type dispatcher struct {
processors []processor
- seed uint32
- wg sync.WaitGroup
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+ wg sync.WaitGroup
}
-func (d *dispatcher) init(nProcessors int) {
+func (d *dispatcher) init(rng *rand.Rand, nProcessors int) {
d.close()
d.wait()
d.processors = make([]processor, nProcessors)
- d.seed = generateRandUint32()
+ d.seed = rng.Uint32()
for i := range d.processors {
p := &d.processors[i]
p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
@@ -212,14 +213,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans
d.selectProcessor(id).queueEndpoint(ep)
}
-func generateRandUint32() uint32 {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
- return binary.LittleEndian.Uint32(b)
-}
-
func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
var payload [4]byte
binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d44f480ab..a27e2110b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -20,12 +20,12 @@ import (
"fmt"
"io"
"math"
+ "math/rand"
"runtime"
"strings"
"sync/atomic"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -882,7 +882,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.ep = e
- e.TSOffset = timeStampOffset()
+ e.TSOffset = timeStampOffset(e.stack.Rand())
e.acceptCond = sync.NewCond(&e.acceptMu)
e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker)
@@ -2215,7 +2215,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
BindToDevice: bindToDevice,
Dest: addr,
}
- if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil {
if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse {
return false, nil
}
@@ -2262,7 +2262,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
BindToDevice: bindToDevice,
Dest: addr,
}
- if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil {
return false, nil
}
}
@@ -2598,7 +2598,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
BindToDevice: bindToDevice,
Dest: tcpip.FullAddress{},
}
- port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) {
+ port, err := e.stack.ReservePort(e.stack.Rand(), portRes, func(p uint16) (bool, tcpip.Error) {
id := e.TransportEndpointInfo.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2878,11 +2878,7 @@ func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 {
// timeStampOffset returns a randomized timestamp offset to be used when sending
// timestamp values in a timestamp option for a TCP segment.
-func timeStampOffset() uint32 {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
+func timeStampOffset(rng *rand.Rand) uint32 {
// Initialize a random tsOffset that will be added to the recentTS
// everytime the timestamp is sent when the Timestamp option is enabled.
//
@@ -2892,7 +2888,7 @@ func timeStampOffset() uint32 {
// NOTE: This is not completely to spec as normally this should be
// initialized in a manner analogous to how sequence numbers are
// randomized per connection basis. But for now this is sufficient.
- return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+ return rng.Uint32()
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 7c38ad480..2fc282e73 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -481,6 +481,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
// TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection.
recovery: 0,
}
- p.dispatcher.init(runtime.GOMAXPROCS(0))
+ p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0))
return &p
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 83b589d09..b964e446b 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1077,7 +1077,7 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
BindToDevice: bindToDevice,
Dest: tcpip.FullAddress{},
}
- port, err := e.stack.ReservePort(portRes, nil /* testPort */)
+ port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */)
if err != nil {
return id, bindToDevice, err
}