diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/accept.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 25 |
1 files changed, 15 insertions, 10 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 1c54dc180..e2bd57ebf 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -23,7 +23,6 @@ import ( "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -131,8 +130,11 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } - rand.Read(l.nonce[0][:]) - rand.Read(l.nonce[1][:]) + for i := range l.nonce { + if _, err := io.ReadFull(stk.SecureRNG(), l.nonce[i][:]); err != nil { + panic(err) + } + } return l } @@ -150,14 +152,17 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc // Feed everything to the hasher. l.hasherMu.Lock() l.hasher.Reset() + + // Per hash.Hash.Writer: + // + // It never returns an error. l.hasher.Write(payload[:]) l.hasher.Write(l.nonce[nonceIndex][:]) - io.WriteString(l.hasher, string(id.LocalAddress)) - io.WriteString(l.hasher, string(id.RemoteAddress)) + l.hasher.Write([]byte(id.LocalAddress)) + l.hasher.Write([]byte(id.RemoteAddress)) // Finalize the calculation of the hash and return the first 4 bytes. - h := make([]byte, 0, sha1.Size) - h = l.hasher.Sum(h) + h := l.hasher.Sum(nil) l.hasherMu.Unlock() return binary.BigEndian.Uint32(h[:]) @@ -196,7 +201,7 @@ func (l *listenContext) useSynCookies() bool { // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -243,7 +248,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) + ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err } @@ -655,7 +660,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{}) if err != nil { return err } |