From db37483cb6acf55b66132d534bb734f09555b1cf Mon Sep 17 00:00:00 2001
From: Andrei Vagin <avagin@google.com>
Date: Wed, 30 Oct 2019 15:32:20 -0700
Subject: Store endpoints inside multiPortEndpoint in a sorted order

It is required to guarantee the same order of endpoints after save/restore.

PiperOrigin-RevId: 277598665
---
 pkg/tcpip/stack/registration.go      |  3 +++
 pkg/tcpip/stack/stack.go             | 29 +++++++++++++++++++++++++++++
 pkg/tcpip/stack/transport_demuxer.go | 10 ++++++++++
 pkg/tcpip/stack/transport_test.go    | 11 ++++++++---
 4 files changed, 50 insertions(+), 3 deletions(-)

(limited to 'pkg/tcpip/stack')

diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 0360187b8..94015ba54 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -60,6 +60,9 @@ const (
 // TransportEndpoint is the interface that needs to be implemented by transport
 // protocol (e.g., tcp, udp) endpoints that can handle packets.
 type TransportEndpoint interface {
+	// UniqueID returns an unique ID for this transport endpoint.
+	UniqueID() uint64
+
 	// HandlePacket is called by the stack when new packets arrive to
 	// this transport endpoint.
 	HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 6d6ddc0ff..115a6fcb8 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
 import (
 	"encoding/binary"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"golang.org/x/time/rate"
@@ -344,6 +345,13 @@ type ResumableEndpoint interface {
 	Resume(*Stack)
 }
 
+// uniqueIDGenerator is a default unique ID generator.
+type uniqueIDGenerator uint64
+
+func (u *uniqueIDGenerator) UniqueID() uint64 {
+	return atomic.AddUint64((*uint64)(u), 1)
+}
+
 // Stack is a networking stack, with all supported protocols, NICs, and route
 // table.
 type Stack struct {
@@ -411,6 +419,14 @@ type Stack struct {
 	// ndpDisp is the NDP event dispatcher that is used to send the netstack
 	// integrator NDP related events.
 	ndpDisp NDPDispatcher
+
+	// uniqueIDGenerator is a generator of unique identifiers.
+	uniqueIDGenerator UniqueID
+}
+
+// UniqueID is an abstract generator of unique identifiers.
+type UniqueID interface {
+	UniqueID() uint64
 }
 
 // Options contains optional Stack configuration.
@@ -434,6 +450,9 @@ type Options struct {
 	// stack (false).
 	HandleLocal bool
 
+	// UniqueID is an optional generator of unique identifiers.
+	UniqueID UniqueID
+
 	// NDPConfigs is the default NDP configurations used by interfaces.
 	//
 	// By default, NDPConfigs will have a zero value for its
@@ -506,6 +525,10 @@ func New(opts Options) *Stack {
 		clock = &tcpip.StdClock{}
 	}
 
+	if opts.UniqueID == nil {
+		opts.UniqueID = new(uniqueIDGenerator)
+	}
+
 	// Make sure opts.NDPConfigs contains valid values only.
 	opts.NDPConfigs.validate()
 
@@ -524,6 +547,7 @@ func New(opts Options) *Stack {
 		portSeed:             generateRandUint32(),
 		ndpConfigs:           opts.NDPConfigs,
 		autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+		uniqueIDGenerator:    opts.UniqueID,
 		ndpDisp:              opts.NDPDisp,
 	}
 
@@ -551,6 +575,11 @@ func New(opts Options) *Stack {
 	return s
 }
 
+// UniqueID returns a unique identifier.
+func (s *Stack) UniqueID() uint64 {
+	return s.uniqueIDGenerator.UniqueID()
+}
+
 // SetNetworkProtocolOption allows configuring individual protocol level
 // options. This method returns an error if the protocol is not supported or
 // option is not supported by the protocol implementation or the provided value
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index f633632f0..ccd3d030e 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -17,6 +17,7 @@ package stack
 import (
 	"fmt"
 	"math/rand"
+	"sort"
 	"sync"
 
 	"gvisor.dev/gvisor/pkg/tcpip"
@@ -310,6 +311,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo
 	// endpointsMap. This will allow us to remove endpoint from the array fast.
 	ep.endpointsMap[t] = len(ep.endpointsArr)
 	ep.endpointsArr = append(ep.endpointsArr, t)
+
+	// ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints
+	// can be restored in the same order.
+	sort.Slice(ep.endpointsArr, func(i, j int) bool {
+		return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID()
+	})
+	for i, e := range ep.endpointsArr {
+		ep.endpointsMap[e] = i
+	}
 	return nil
 }
 
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index ae6fda3a9..203e79f56 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -43,6 +43,7 @@ type fakeTransportEndpoint struct {
 	proto    *fakeTransportProtocol
 	peerAddr tcpip.Address
 	route    stack.Route
+	uniqueID uint64
 
 	// acceptQueue is non-nil iff bound.
 	acceptQueue []fakeTransportEndpoint
@@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
 	return nil
 }
 
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
-	return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto}
+func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+	return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
 }
 
 func (f *fakeTransportEndpoint) Close() {
@@ -144,6 +145,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
 	return nil
 }
 
+func (f *fakeTransportEndpoint) UniqueID() uint64 {
+	return f.uniqueID
+}
+
 func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
 	return nil
 }
@@ -251,7 +256,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
 }
 
 func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
-	return newFakeTransportEndpoint(stack, f, netProto), nil
+	return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
 }
 
 func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
-- 
cgit v1.2.3