diff options
author | gVisor bot <gvisor-bot@google.com> | 2021-05-27 01:20:10 +0000 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-05-27 01:20:10 +0000 |
commit | c81dd74d61e4cb78c8e5526fbf47084a97af8b9f (patch) | |
tree | 55f006c48392655464f8ff15e99275bd0483d916 /pkg/tcpip/stack | |
parent | 04322f810ff9ebbec03962e6ae43e2788a7bcd0f (diff) | |
parent | 097efe81a19a6ee11738957a3091e99a2caa46d4 (diff) |
Merge release-20210518.0-52-g097efe81a (automated)
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 55 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 8 |
3 files changed, 29 insertions, 38 deletions
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index d2f666c09..0a26f6dd8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() *IPTables { +func DefaultTables(seed uint32) *IPTables { return &IPTables{ v4Tables: [NumTables]Table{ NATID: { @@ -182,7 +182,7 @@ func DefaultTables() *IPTables { Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ - seed: generateRandUint32(), + seed: seed, }, reaperDone: make(chan struct{}, 1), } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 1ebf9670c..40d277312 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,7 +20,6 @@ package stack import ( - "bytes" "encoding/binary" "fmt" "io" @@ -223,10 +222,16 @@ type Options struct { // RandSource must be thread-safe. RandSource rand.Source - // IPTables are the initial iptables rules. If nil, iptables will allow + // IPTables are the initial iptables rules. If nil, DefaultIPTables will be + // used to construct the initial iptables rules. // all traffic. IPTables *IPTables + // DefaultIPTables is an optional iptables rules constructor that is called + // if IPTables is nil. If both fields are nil, iptables will allow all + // traffic. + DefaultIPTables func(uint32) *IPTables + // SecureRNG is a cryptographically secure random number generator. SecureRNG io.Reader } @@ -324,23 +329,32 @@ func New(opts Options) *Stack { opts.UniqueID = new(uniqueIDGenerator) } + if opts.SecureRNG == nil { + opts.SecureRNG = cryptorand.Reader + } + randSrc := opts.RandSource if randSrc == nil { + var v int64 + if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil { + panic(err) + } // Source provided by rand.NewSource is not thread-safe so // we wrap it in a simple thread-safe version. - randSrc = &lockedRandomSource{src: rand.NewSource(generateRandInt64())} + randSrc = &lockedRandomSource{src: rand.NewSource(v)} } + randomGenerator := rand.New(randSrc) + seed := randomGenerator.Uint32() if opts.IPTables == nil { - opts.IPTables = DefaultTables() + if opts.DefaultIPTables == nil { + opts.DefaultIPTables = DefaultTables + } + opts.IPTables = opts.DefaultIPTables(seed) } opts.NUDConfigs.resetInvalidFields() - if opts.SecureRNG == nil { - opts.SecureRNG = cryptorand.Reader - } - s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -353,11 +367,11 @@ func New(opts Options) *Stack { handleLocal: opts.HandleLocal, tables: opts.IPTables, icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), + seed: seed, nudConfigs: opts.NUDConfigs, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, - randomGenerator: rand.New(randSrc), + randomGenerator: randomGenerator, secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, @@ -1822,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader { return s.secureRNG } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := cryptorand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - -func generateRandInt64() int64 { - b := make([]byte, 8) - if _, err := cryptorand.Read(b); err != nil { - panic(err) - } - buf := bytes.NewReader(b) - var v int64 - if err := binary.Read(buf, binary.LittleEndian, &v); err != nil { - panic(err) - } - return v -} - // FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 80ad1a9d4..8a8454a6a 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -16,8 +16,6 @@ package stack import ( "fmt" - "math/rand" - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" @@ -223,7 +221,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } -func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { +func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -475,7 +473,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: rand.Uint32(), + seed: d.stack.Seed(), } eps.endpoints[id] = epsByNIC } @@ -502,7 +500,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum return nil } - return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) + return epsByNIC.checkEndpoint(flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it |