summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2021-05-27 01:20:10 +0000
committergVisor bot <gvisor-bot@google.com>2021-05-27 01:20:10 +0000
commitc81dd74d61e4cb78c8e5526fbf47084a97af8b9f (patch)
tree55f006c48392655464f8ff15e99275bd0483d916 /pkg/tcpip/stack
parent04322f810ff9ebbec03962e6ae43e2788a7bcd0f (diff)
parent097efe81a19a6ee11738957a3091e99a2caa46d4 (diff)
Merge release-20210518.0-52-g097efe81a (automated)
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/iptables.go4
-rw-r--r--pkg/tcpip/stack/stack.go55
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go8
3 files changed, 29 insertions, 38 deletions
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index d2f666c09..0a26f6dd8 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() *IPTables {
+func DefaultTables(seed uint32) *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
NATID: {
@@ -182,7 +182,7 @@ func DefaultTables() *IPTables {
Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
- seed: generateRandUint32(),
+ seed: seed,
},
reaperDone: make(chan struct{}, 1),
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 1ebf9670c..40d277312 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,7 +20,6 @@
package stack
import (
- "bytes"
"encoding/binary"
"fmt"
"io"
@@ -223,10 +222,16 @@ type Options struct {
// RandSource must be thread-safe.
RandSource rand.Source
- // IPTables are the initial iptables rules. If nil, iptables will allow
+ // IPTables are the initial iptables rules. If nil, DefaultIPTables will be
+ // used to construct the initial iptables rules.
// all traffic.
IPTables *IPTables
+ // DefaultIPTables is an optional iptables rules constructor that is called
+ // if IPTables is nil. If both fields are nil, iptables will allow all
+ // traffic.
+ DefaultIPTables func(uint32) *IPTables
+
// SecureRNG is a cryptographically secure random number generator.
SecureRNG io.Reader
}
@@ -324,23 +329,32 @@ func New(opts Options) *Stack {
opts.UniqueID = new(uniqueIDGenerator)
}
+ if opts.SecureRNG == nil {
+ opts.SecureRNG = cryptorand.Reader
+ }
+
randSrc := opts.RandSource
if randSrc == nil {
+ var v int64
+ if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil {
+ panic(err)
+ }
// Source provided by rand.NewSource is not thread-safe so
// we wrap it in a simple thread-safe version.
- randSrc = &lockedRandomSource{src: rand.NewSource(generateRandInt64())}
+ randSrc = &lockedRandomSource{src: rand.NewSource(v)}
}
+ randomGenerator := rand.New(randSrc)
+ seed := randomGenerator.Uint32()
if opts.IPTables == nil {
- opts.IPTables = DefaultTables()
+ if opts.DefaultIPTables == nil {
+ opts.DefaultIPTables = DefaultTables
+ }
+ opts.IPTables = opts.DefaultIPTables(seed)
}
opts.NUDConfigs.resetInvalidFields()
- if opts.SecureRNG == nil {
- opts.SecureRNG = cryptorand.Reader
- }
-
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
@@ -353,11 +367,11 @@ func New(opts Options) *Stack {
handleLocal: opts.HandleLocal,
tables: opts.IPTables,
icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
+ seed: seed,
nudConfigs: opts.NUDConfigs,
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
- randomGenerator: rand.New(randSrc),
+ randomGenerator: randomGenerator,
secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
@@ -1822,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader {
return s.secureRNG
}
-func generateRandUint32() uint32 {
- b := make([]byte, 4)
- if _, err := cryptorand.Read(b); err != nil {
- panic(err)
- }
- return binary.LittleEndian.Uint32(b)
-}
-
-func generateRandInt64() int64 {
- b := make([]byte, 8)
- if _, err := cryptorand.Read(b); err != nil {
- panic(err)
- }
- buf := bytes.NewReader(b)
- var v int64
- if err := binary.Read(buf, binary.LittleEndian, &v); err != nil {
- panic(err)
- }
- return v
-}
-
// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
s.mu.RLock()
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 80ad1a9d4..8a8454a6a 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -16,8 +16,6 @@ package stack
import (
"fmt"
- "math/rand"
-
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
@@ -223,7 +221,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
return multiPortEp.singleRegisterEndpoint(t, flags)
}
-func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
+func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -475,7 +473,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: rand.Uint32(),
+ seed: d.stack.Seed(),
}
eps.endpoints[id] = epsByNIC
}
@@ -502,7 +500,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum
return nil
}
- return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice)
+ return epsByNIC.checkEndpoint(flags, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it