summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@google.com>2021-04-10 14:52:00 -0700
committergVisor bot <gvisor-bot@google.com>2021-04-10 14:53:55 -0700
commitc84ff991240c0ec71dd1978db250bcbfbe4c142b (patch)
tree721d5bf6b26139a5cedd6b9e04b7e71c4db0c069 /pkg/tcpip/transport
parent2fea7d096b6224da50e09fa4bace7f3c203ed074 (diff)
Use the SecureRNG to generate listener nonces
Some other cleanup while I'm here: - Remove unused arguments - Handle some unhandled errors - Remove redundant casts - Remove redundant parens - Avoid shadowing `hash` package name PiperOrigin-RevId: 367816161
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go25
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go24
2 files changed, 32 insertions, 17 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 1c54dc180..e2bd57ebf 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -23,7 +23,6 @@ import (
"sync/atomic"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -131,8 +130,11 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
- rand.Read(l.nonce[0][:])
- rand.Read(l.nonce[1][:])
+ for i := range l.nonce {
+ if _, err := io.ReadFull(stk.SecureRNG(), l.nonce[i][:]); err != nil {
+ panic(err)
+ }
+ }
return l
}
@@ -150,14 +152,17 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc
// Feed everything to the hasher.
l.hasherMu.Lock()
l.hasher.Reset()
+
+ // Per hash.Hash.Writer:
+ //
+ // It never returns an error.
l.hasher.Write(payload[:])
l.hasher.Write(l.nonce[nonceIndex][:])
- io.WriteString(l.hasher, string(id.LocalAddress))
- io.WriteString(l.hasher, string(id.RemoteAddress))
+ l.hasher.Write([]byte(id.LocalAddress))
+ l.hasher.Write([]byte(id.RemoteAddress))
// Finalize the calculation of the hash and return the first 4 bytes.
- h := make([]byte, 0, sha1.Size)
- h = l.hasher.Sum(h)
+ h := l.hasher.Sum(nil)
l.hasherMu.Unlock()
return binary.BigEndian.Uint32(h[:])
@@ -196,7 +201,7 @@ func (l *listenContext) useSynCookies() bool {
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -243,7 +248,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ ep, err := l.createConnectingEndpoint(s, opts, queue)
if err != nil {
return nil, err
}
@@ -655,7 +660,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{})
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 9fbaf6f4b..1060a0a90 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -877,7 +877,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
waiterQueue: waiterQueue,
state: StateInitial,
rcvBufSize: DefaultReceiveBufferSize,
- sndMTU: int(math.MaxInt32),
+ sndMTU: math.MaxInt32,
keepalive: keepalive{
// Linux defaults.
idle: 2 * time.Hour,
@@ -1703,7 +1703,7 @@ func (e *endpoint) OnReusePortSet(v bool) {
}
// OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet.
-func (e *endpoint) OnKeepAliveSet(v bool) {
+func (e *endpoint) OnKeepAliveSet(bool) {
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
}
@@ -2235,12 +2235,22 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// src IP to ensure that for a given tuple (srcIP, destIP,
// destPort) the offset used as a starting point is the same to
// ensure that we can cycle through the port space effectively.
- h := jenkins.Sum32(e.stack.Seed())
- h.Write([]byte(e.ID.LocalAddress))
- h.Write([]byte(e.ID.RemoteAddress))
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
- h.Write(portBuf)
+
+ h := jenkins.Sum32(e.stack.Seed())
+ for _, s := range [][]byte{
+ []byte(e.ID.LocalAddress),
+ []byte(e.ID.RemoteAddress),
+ portBuf,
+ } {
+ // Per io.Writer.Write:
+ //
+ // Write must return a non-nil error if it returns n < len(p).
+ if _, err := h.Write(s); err != nil {
+ panic(err)
+ }
+ }
portOffset := uint16(h.Sum32())
var twReuse tcpip.TCPTimeWaitReuseOption
@@ -2807,7 +2817,7 @@ func (e *endpoint) updateSndBufferUsage(v int) {
// We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndBufUsed < int(sendBufferSize)>>1
+ notify = notify && e.sndBufUsed < sendBufferSize>>1
e.sndBufMu.Unlock()
if notify {