diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/stack/stack.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 12 |
8 files changed, 28 insertions, 22 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c73890c4c..e0c5e5e28 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -119,8 +119,7 @@ type Stack struct { // by the stack. icmpRateLimiter *ICMPRateLimiter - // seed is a one-time random value initialized at stack startup - // and is used to seed the TCP port picking on active connections + // seed is a one-time random value initialized at stack startup. // // TODO(gvisor.dev/issue/940): S/R this field. seed uint32 @@ -1819,14 +1818,6 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol return nic.setNUDConfigs(proto, c) } -// Seed returns a 32 bit value that can be used as a seed value for port -// picking, ISN generation etc. -// -// NOTE: The seed is generated once during stack initialization only. -func (s *Stack) Seed() uint32 { - return s.seed -} - // Rand returns a reference to a pseudo random generator that can be used // to generate random numbers as required. func (s *Stack) Rand() *rand.Rand { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index dda57e225..824cf6526 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -479,7 +479,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: d.stack.Seed(), + seed: d.stack.seed, } } if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index aa413ad05..e4cc1b9f1 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -72,7 +72,8 @@ func encodeMSS(mss uint16) uint32 { // and must not be accessed or have its methods called concurrently as they // may mutate the stored objects. type listenContext struct { - stack *stack.Stack + stack *stack.Stack + protocol *protocol // rcvWnd is the receive window that is sent by this listening context // in the initial SYN-ACK. @@ -119,9 +120,10 @@ func timeStamp(clock tcpip.Clock) uint32 { } // newListenContext creates a new listen context. -func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ stack: stk, + protocol: protocol, rcvWnd: rcvWnd, hasher: sha1.New(), v6Only: v6Only, @@ -213,7 +215,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header return nil, err } - n := newEndpoint(l.stack, netProto, queue) + n := newEndpoint(l.stack, l.protocol, netProto, queue) n.ops.SetV6Only(l.v6Only) n.TransportEndpointInfo.ID = s.id n.boundNICID = s.nicID @@ -247,7 +249,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed()) + isn := generateSecureISN(s.id, l.stack.Clock(), l.protocol.seqnumSecret) ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err @@ -779,7 +781,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() v6Only := e.ops.GetV6Only() - ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) + ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 93ed161f9..e0f5e41b2 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -150,7 +150,7 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.protocol.seqnumSecret) } // generateSecureISN generates a secure Initial Sequence number based on the diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 355719beb..1fc49033f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -378,6 +378,7 @@ type endpoint struct { // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` + protocol *protocol `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` uniqueID uint64 @@ -803,9 +804,10 @@ type keepalive struct { waker sleep.Waker `state:"nosave"` } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: s, + stack: s, + protocol: protocol, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, TransProto: header.TCPProtocolNumber, @@ -2198,7 +2200,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) - h := jenkins.Sum32(e.stack.Seed()) + h := jenkins.Sum32(e.protocol.portOffsetSecret) for _, s := range [][]byte{ []byte(e.ID.LocalAddress), []byte(e.ID.RemoteAddress), diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 952ccacdd..f2e8b3840 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -170,6 +170,7 @@ func (e *endpoint) Resume(s *stack.Stack) { snd.probeTimer.init(s.Clock(), &snd.probeWaker) } e.stack = s + e.protocol = protocolFromStack(s) e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() epState := EndpointState(e.origEndpointState) diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2e709ed78..78745ea86 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -54,7 +54,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward maxInFlight: maxInFlight, handler: handler, inFlight: make(map[stack.TransportEndpointID]struct{}), - listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0), + listen: newListenContext(s, protocolFromStack(s), nil /* listenEP */, seqnum.Size(rcvWnd), true, 0), } } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 18b834243..0b30cd3bb 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -96,6 +96,9 @@ type protocol struct { maxRetries uint32 synRetries uint8 dispatcher dispatcher + // The following secrets are initialized once and stay unchanged after. + seqnumSecret uint32 + portOffsetSecret uint32 } // Number returns the tcp protocol number. @@ -105,7 +108,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { // NewEndpoint creates a new tcp endpoint. func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { - return newEndpoint(p.stack, netProto, waiterQueue), nil + return newEndpoint(p.stack, p, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently @@ -479,7 +482,14 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { maxRTO: MaxRTO, maxRetries: MaxRetries, recovery: tcpip.TCPRACKLossDetection, + seqnumSecret: s.Rand().Uint32(), + portOffsetSecret: s.Rand().Uint32(), } p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0)) return &p } + +// protocolFromStack retrieves the tcp.protocol instance from stack s. +func protocolFromStack(s *stack.Stack) *protocol { + return s.TransportProtocolInstance(ProtocolNumber).(*protocol) +} |