summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/ports/ports.go10
-rw-r--r--pkg/tcpip/stack/iptables.go4
-rw-r--r--pkg/tcpip/stack/stack.go55
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go8
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go6
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go18
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
11 files changed, 51 insertions, 77 deletions
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index b5b013b64..854d6a6ba 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -101,7 +101,7 @@ func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) {
// Wildcard destinations affect all destinations for TupleOnly.
if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress {
// Only bitwise and the TupleOnlyFlag.
- intersection &= ((^TupleOnlyFlag) | counter.SharedFlags())
+ intersection &= (^TupleOnlyFlag) | counter.SharedFlags()
count++
}
}
@@ -238,13 +238,13 @@ type PortTester func(port uint16) (good bool, err tcpip.Error)
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
-func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) {
+func (pm *PortManager) PickEphemeralPort(rng *rand.Rand, testPort PortTester) (port uint16, err tcpip.Error) {
pm.ephemeralMu.RLock()
firstEphemeral := pm.firstEphemeral
numEphemeral := pm.numEphemeral
pm.ephemeralMu.RUnlock()
- offset := uint32(rand.Int31n(int32(numEphemeral)))
+ offset := uint32(rng.Int31n(int32(numEphemeral)))
return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
}
@@ -303,7 +303,7 @@ func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester)
// An optional PortTester can be passed in which if provided will be used to
// test if the picked port can be used. The function should return true if the
// port is safe to use, false otherwise.
-func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
+func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -328,7 +328,7 @@ func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reserv
}
// A port wasn't specified, so try to find one.
- return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) {
res.Port = p
if !pm.reserveSpecificPortLocked(res) {
return false, nil
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index d2f666c09..0a26f6dd8 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() *IPTables {
+func DefaultTables(seed uint32) *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
NATID: {
@@ -182,7 +182,7 @@ func DefaultTables() *IPTables {
Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
- seed: generateRandUint32(),
+ seed: seed,
},
reaperDone: make(chan struct{}, 1),
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 1ebf9670c..40d277312 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,7 +20,6 @@
package stack
import (
- "bytes"
"encoding/binary"
"fmt"
"io"
@@ -223,10 +222,16 @@ type Options struct {
// RandSource must be thread-safe.
RandSource rand.Source
- // IPTables are the initial iptables rules. If nil, iptables will allow
+ // IPTables are the initial iptables rules. If nil, DefaultIPTables will be
+ // used to construct the initial iptables rules.
// all traffic.
IPTables *IPTables
+ // DefaultIPTables is an optional iptables rules constructor that is called
+ // if IPTables is nil. If both fields are nil, iptables will allow all
+ // traffic.
+ DefaultIPTables func(uint32) *IPTables
+
// SecureRNG is a cryptographically secure random number generator.
SecureRNG io.Reader
}
@@ -324,23 +329,32 @@ func New(opts Options) *Stack {
opts.UniqueID = new(uniqueIDGenerator)
}
+ if opts.SecureRNG == nil {
+ opts.SecureRNG = cryptorand.Reader
+ }
+
randSrc := opts.RandSource
if randSrc == nil {
+ var v int64
+ if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil {
+ panic(err)
+ }
// Source provided by rand.NewSource is not thread-safe so
// we wrap it in a simple thread-safe version.
- randSrc = &lockedRandomSource{src: rand.NewSource(generateRandInt64())}
+ randSrc = &lockedRandomSource{src: rand.NewSource(v)}
}
+ randomGenerator := rand.New(randSrc)
+ seed := randomGenerator.Uint32()
if opts.IPTables == nil {
- opts.IPTables = DefaultTables()
+ if opts.DefaultIPTables == nil {
+ opts.DefaultIPTables = DefaultTables
+ }
+ opts.IPTables = opts.DefaultIPTables(seed)
}
opts.NUDConfigs.resetInvalidFields()
- if opts.SecureRNG == nil {
- opts.SecureRNG = cryptorand.Reader
- }
-
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
@@ -353,11 +367,11 @@ func New(opts Options) *Stack {
handleLocal: opts.HandleLocal,
tables: opts.IPTables,
icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
+ seed: seed,
nudConfigs: opts.NUDConfigs,
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
- randomGenerator: rand.New(randSrc),
+ randomGenerator: randomGenerator,
secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
@@ -1822,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader {
return s.secureRNG
}
-func generateRandUint32() uint32 {
- b := make([]byte, 4)
- if _, err := cryptorand.Read(b); err != nil {
- panic(err)
- }
- return binary.LittleEndian.Uint32(b)
-}
-
-func generateRandInt64() int64 {
- b := make([]byte, 8)
- if _, err := cryptorand.Read(b); err != nil {
- panic(err)
- }
- buf := bytes.NewReader(b)
- var v int64
- if err := binary.Read(buf, binary.LittleEndian, &v); err != nil {
- panic(err)
- }
- return v
-}
-
// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
s.mu.RLock()
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 80ad1a9d4..8a8454a6a 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -16,8 +16,6 @@ package stack
import (
"fmt"
- "math/rand"
-
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
@@ -223,7 +221,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
return multiPortEp.singleRegisterEndpoint(t, flags)
}
-func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
+func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -475,7 +473,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: rand.Uint32(),
+ seed: d.stack.Seed(),
}
eps.endpoints[id] = epsByNIC
}
@@ -502,7 +500,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum
return nil
}
- return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice)
+ return epsByNIC.checkEndpoint(flags, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d92b123c4..39f526023 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -616,7 +616,7 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro
}
// We need to find a port for the endpoint.
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
switch err.(type) {
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index db75228a1..2c65b737d 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -591,7 +591,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset()),
+ TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())),
TSEcr: opts.TSVal,
MSS: calculateAdvertisedMSS(e.userMSS, route),
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 08849d0cc..570e5081c 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -19,7 +19,6 @@ import (
"math"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -147,11 +146,6 @@ func FindWndScale(wnd seqnum.Size) int {
// resetState resets the state of the handshake object such that it becomes
// ready for a new 3-way handshake.
func (h *handshake) resetState() {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
-
h.state = handshakeSynSent
h.flags = header.TCPFlagSyn
h.ackNum = 0
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 5963a169c..dff7cb89c 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -16,8 +16,8 @@ package tcp
import (
"encoding/binary"
+ "math/rand"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -142,15 +142,16 @@ func (p *processor) start(wg *sync.WaitGroup) {
// in-order.
type dispatcher struct {
processors []processor
- seed uint32
- wg sync.WaitGroup
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+ wg sync.WaitGroup
}
-func (d *dispatcher) init(nProcessors int) {
+func (d *dispatcher) init(rng *rand.Rand, nProcessors int) {
d.close()
d.wait()
d.processors = make([]processor, nProcessors)
- d.seed = generateRandUint32()
+ d.seed = rng.Uint32()
for i := range d.processors {
p := &d.processors[i]
p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
@@ -212,14 +213,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans
d.selectProcessor(id).queueEndpoint(ep)
}
-func generateRandUint32() uint32 {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
- return binary.LittleEndian.Uint32(b)
-}
-
func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
var payload [4]byte
binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d44f480ab..a27e2110b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -20,12 +20,12 @@ import (
"fmt"
"io"
"math"
+ "math/rand"
"runtime"
"strings"
"sync/atomic"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -882,7 +882,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.ep = e
- e.TSOffset = timeStampOffset()
+ e.TSOffset = timeStampOffset(e.stack.Rand())
e.acceptCond = sync.NewCond(&e.acceptMu)
e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker)
@@ -2215,7 +2215,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
BindToDevice: bindToDevice,
Dest: addr,
}
- if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil {
if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse {
return false, nil
}
@@ -2262,7 +2262,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
BindToDevice: bindToDevice,
Dest: addr,
}
- if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil {
return false, nil
}
}
@@ -2598,7 +2598,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
BindToDevice: bindToDevice,
Dest: tcpip.FullAddress{},
}
- port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) {
+ port, err := e.stack.ReservePort(e.stack.Rand(), portRes, func(p uint16) (bool, tcpip.Error) {
id := e.TransportEndpointInfo.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2878,11 +2878,7 @@ func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 {
// timeStampOffset returns a randomized timestamp offset to be used when sending
// timestamp values in a timestamp option for a TCP segment.
-func timeStampOffset() uint32 {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
+func timeStampOffset(rng *rand.Rand) uint32 {
// Initialize a random tsOffset that will be added to the recentTS
// everytime the timestamp is sent when the Timestamp option is enabled.
//
@@ -2892,7 +2888,7 @@ func timeStampOffset() uint32 {
// NOTE: This is not completely to spec as normally this should be
// initialized in a manner analogous to how sequence numbers are
// randomized per connection basis. But for now this is sufficient.
- return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+ return rng.Uint32()
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 7c38ad480..2fc282e73 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -481,6 +481,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
// TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection.
recovery: 0,
}
- p.dispatcher.init(runtime.GOMAXPROCS(0))
+ p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0))
return &p
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 83b589d09..b964e446b 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1077,7 +1077,7 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
BindToDevice: bindToDevice,
Dest: tcpip.FullAddress{},
}
- port, err := e.stack.ReservePort(portRes, nil /* testPort */)
+ port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */)
if err != nil {
return id, bindToDevice, err
}