diff options
author | Tamir Duberstein <tamird@google.com> | 2021-04-10 14:52:00 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-04-10 14:53:55 -0700 |
commit | c84ff991240c0ec71dd1978db250bcbfbe4c142b (patch) | |
tree | 721d5bf6b26139a5cedd6b9e04b7e71c4db0c069 | |
parent | 2fea7d096b6224da50e09fa4bace7f3c203ed074 (diff) |
Use the SecureRNG to generate listener nonces
Some other cleanup while I'm here:
- Remove unused arguments
- Handle some unhandled errors
- Remove redundant casts
- Remove redundant parens
- Avoid shadowing `hash` package name
PiperOrigin-RevId: 367816161
-rw-r--r-- | pkg/tcpip/hash/jenkins/jenkins.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 24 |
3 files changed, 42 insertions, 27 deletions
diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go index 52c22230e..33ff22a7b 100644 --- a/pkg/tcpip/hash/jenkins/jenkins.go +++ b/pkg/tcpip/hash/jenkins/jenkins.go @@ -42,26 +42,26 @@ func (s *Sum32) Reset() { *s = 0 } // Sum32 returns the hash value func (s *Sum32) Sum32() uint32 { - hash := *s + sCopy := *s - hash += (hash << 3) - hash ^= hash >> 11 - hash += hash << 15 + sCopy += sCopy << 3 + sCopy ^= sCopy >> 11 + sCopy += sCopy << 15 - return uint32(hash) + return uint32(sCopy) } // Write adds more data to the running hash. // // It never returns an error. func (s *Sum32) Write(data []byte) (int, error) { - hash := *s + sCopy := *s for _, b := range data { - hash += Sum32(b) - hash += hash << 10 - hash ^= hash >> 6 + sCopy += Sum32(b) + sCopy += sCopy << 10 + sCopy ^= sCopy >> 6 } - *s = hash + *s = sCopy return len(data), nil } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 1c54dc180..e2bd57ebf 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -23,7 +23,6 @@ import ( "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -131,8 +130,11 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } - rand.Read(l.nonce[0][:]) - rand.Read(l.nonce[1][:]) + for i := range l.nonce { + if _, err := io.ReadFull(stk.SecureRNG(), l.nonce[i][:]); err != nil { + panic(err) + } + } return l } @@ -150,14 +152,17 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc // Feed everything to the hasher. l.hasherMu.Lock() l.hasher.Reset() + + // Per hash.Hash.Writer: + // + // It never returns an error. l.hasher.Write(payload[:]) l.hasher.Write(l.nonce[nonceIndex][:]) - io.WriteString(l.hasher, string(id.LocalAddress)) - io.WriteString(l.hasher, string(id.RemoteAddress)) + l.hasher.Write([]byte(id.LocalAddress)) + l.hasher.Write([]byte(id.RemoteAddress)) // Finalize the calculation of the hash and return the first 4 bytes. - h := make([]byte, 0, sha1.Size) - h = l.hasher.Sum(h) + h := l.hasher.Sum(nil) l.hasherMu.Unlock() return binary.BigEndian.Uint32(h[:]) @@ -196,7 +201,7 @@ func (l *listenContext) useSynCookies() bool { // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -243,7 +248,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) + ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err } @@ -655,7 +660,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{}) if err != nil { return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9fbaf6f4b..1060a0a90 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -877,7 +877,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, - sndMTU: int(math.MaxInt32), + sndMTU: math.MaxInt32, keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -1703,7 +1703,7 @@ func (e *endpoint) OnReusePortSet(v bool) { } // OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet. -func (e *endpoint) OnKeepAliveSet(v bool) { +func (e *endpoint) OnKeepAliveSet(bool) { e.notifyProtocolGoroutine(notifyKeepaliveChanged) } @@ -2235,12 +2235,22 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.Seed()) - h.Write([]byte(e.ID.LocalAddress)) - h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) - h.Write(portBuf) + + h := jenkins.Sum32(e.stack.Seed()) + for _, s := range [][]byte{ + []byte(e.ID.LocalAddress), + []byte(e.ID.RemoteAddress), + portBuf, + } { + // Per io.Writer.Write: + // + // Write must return a non-nil error if it returns n < len(p). + if _, err := h.Write(s); err != nil { + panic(err) + } + } portOffset := uint16(h.Sum32()) var twReuse tcpip.TCPTimeWaitReuseOption @@ -2807,7 +2817,7 @@ func (e *endpoint) updateSndBufferUsage(v int) { // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndBufUsed < int(sendBufferSize)>>1 + notify = notify && e.sndBufUsed < sendBufferSize>>1 e.sndBufMu.Unlock() if notify { |