summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/stack/registration.go3
-rw-r--r--pkg/tcpip/stack/stack.go29
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go10
-rw-r--r--pkg/tcpip/stack/transport_test.go11
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go7
-rw-r--r--runsc/boot/loader.go5
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc107
9 files changed, 181 insertions, 5 deletions
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 0360187b8..94015ba54 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -60,6 +60,9 @@ const (
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
+ // UniqueID returns an unique ID for this transport endpoint.
+ UniqueID() uint64
+
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint.
HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 6d6ddc0ff..115a6fcb8 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
import (
"encoding/binary"
"sync"
+ "sync/atomic"
"time"
"golang.org/x/time/rate"
@@ -344,6 +345,13 @@ type ResumableEndpoint interface {
Resume(*Stack)
}
+// uniqueIDGenerator is a default unique ID generator.
+type uniqueIDGenerator uint64
+
+func (u *uniqueIDGenerator) UniqueID() uint64 {
+ return atomic.AddUint64((*uint64)(u), 1)
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -411,6 +419,14 @@ type Stack struct {
// ndpDisp is the NDP event dispatcher that is used to send the netstack
// integrator NDP related events.
ndpDisp NDPDispatcher
+
+ // uniqueIDGenerator is a generator of unique identifiers.
+ uniqueIDGenerator UniqueID
+}
+
+// UniqueID is an abstract generator of unique identifiers.
+type UniqueID interface {
+ UniqueID() uint64
}
// Options contains optional Stack configuration.
@@ -434,6 +450,9 @@ type Options struct {
// stack (false).
HandleLocal bool
+ // UniqueID is an optional generator of unique identifiers.
+ UniqueID UniqueID
+
// NDPConfigs is the default NDP configurations used by interfaces.
//
// By default, NDPConfigs will have a zero value for its
@@ -506,6 +525,10 @@ func New(opts Options) *Stack {
clock = &tcpip.StdClock{}
}
+ if opts.UniqueID == nil {
+ opts.UniqueID = new(uniqueIDGenerator)
+ }
+
// Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate()
@@ -524,6 +547,7 @@ func New(opts Options) *Stack {
portSeed: generateRandUint32(),
ndpConfigs: opts.NDPConfigs,
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
}
@@ -551,6 +575,11 @@ func New(opts Options) *Stack {
return s
}
+// UniqueID returns a unique identifier.
+func (s *Stack) UniqueID() uint64 {
+ return s.uniqueIDGenerator.UniqueID()
+}
+
// SetNetworkProtocolOption allows configuring individual protocol level
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index f633632f0..ccd3d030e 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"math/rand"
+ "sort"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -310,6 +311,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo
// endpointsMap. This will allow us to remove endpoint from the array fast.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+
+ // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints
+ // can be restored in the same order.
+ sort.Slice(ep.endpointsArr, func(i, j int) bool {
+ return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID()
+ })
+ for i, e := range ep.endpointsArr {
+ ep.endpointsMap[e] = i
+ }
return nil
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index ae6fda3a9..203e79f56 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -43,6 +43,7 @@ type fakeTransportEndpoint struct {
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
+ uniqueID uint64
// acceptQueue is non-nil iff bound.
acceptQueue []fakeTransportEndpoint
@@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
return nil
}
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto}
+func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
func (f *fakeTransportEndpoint) Close() {
@@ -144,6 +145,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return nil
}
+func (f *fakeTransportEndpoint) UniqueID() uint64 {
+ return f.uniqueID
+}
+
func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
@@ -251,7 +256,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
}
func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto), nil
+ return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d0dd383fd..114a69b4e 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -58,6 +58,7 @@ type endpoint struct {
// immutable.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
+ uniqueID uint64
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -90,9 +91,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
state: stateInitial,
+ uniqueID: s.UniqueID(),
}, nil
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 8a3ca0f1b..a1efd8d55 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -287,6 +287,7 @@ type endpoint struct {
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
+ uniqueID uint64
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
@@ -504,6 +505,11 @@ type endpoint struct {
stats Stats `state:"nosave"`
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
// calculateAdvertisedMSS calculates the MSS to advertise.
//
// If userMSS is non-zero and is not greater than the maximum possible MSS for
@@ -565,6 +571,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
interval: 75 * time.Second,
count: 9,
},
+ uniqueID: s.UniqueID(),
}
var ss SendBufferSizeOption
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index cda302bb7..68977dc25 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -80,6 +80,7 @@ type endpoint struct {
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
+ uniqueID uint64
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -160,9 +161,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
state: StateInitial,
+ uniqueID: s.UniqueID(),
}
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 0c0eba99e..86df384f8 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -232,7 +232,7 @@ func New(args Args) (*Loader, error) {
// this point. Netns is configured before Run() is called. Netstack is
// configured using a control uRPC message. Host network is configured inside
// Run().
- networkStack, err := newEmptyNetworkStack(args.Conf, k)
+ networkStack, err := newEmptyNetworkStack(args.Conf, k, k)
if err != nil {
return nil, fmt.Errorf("creating network: %v", err)
}
@@ -905,7 +905,7 @@ func (l *Loader) WaitExit() kernel.ExitStatus {
return l.k.GlobalInit().ExitStatus()
}
-func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) {
+func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
switch conf.Network {
case NetworkHost:
return hostinet.NewStack(), nil
@@ -923,6 +923,7 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) {
// Enable raw sockets for users with sufficient
// privileges.
RawFactory: raw.EndpointFactory{},
+ UniqueID: uniqueID,
})}
// Enable SACK Recovery.
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 322ee07ad..ab375aaaf 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -16,6 +16,7 @@
#include <netinet/in.h>
#include <poll.h>
#include <string.h>
+#include <sys/epoll.h>
#include <sys/socket.h>
#include <atomic>
@@ -516,6 +517,112 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
}
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+ constexpr int kThreadCount = 3;
+
+ // TODO(b/141211329): endpointsByNic.seed has to be saved/restored.
+ const DisableSave ds141211329;
+
+ // Create listening sockets.
+ FileDescriptor listener_fds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ listener_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0));
+ int fd = listener_fds[i].get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (i != 0) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10;
+ FileDescriptor client_fds[kConnectAttempts];
+
+ // Do the first run without save/restore.
+ DisableSave ds;
+ for (int i = 0; i < kConnectAttempts; i++) {
+ client_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+ ds.reset();
+
+ // Check that a mapping of client and server sockets has
+ // not been change after save/restore.
+ for (int i = 0; i < kConnectAttempts; i++) {
+ EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+
+ int epollfd;
+ ASSERT_THAT(epollfd = epoll_create1(0), SyscallSucceeds());
+
+ for (int i = 0; i < kThreadCount; i++) {
+ int fd = listener_fds[i].get();
+ struct epoll_event ev;
+ ev.data.fd = fd;
+ ev.events = EPOLLIN;
+ ASSERT_THAT(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev), SyscallSucceeds());
+ }
+
+ std::map<uint16_t, int> portToFD;
+
+ for (int i = 0; i < kConnectAttempts * 2; i++) {
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ struct epoll_event ev;
+ int data, fd;
+
+ ASSERT_THAT(epoll_wait(epollfd, &ev, 1, -1), SyscallSucceedsWithValue(1));
+
+ fd = ev.data.fd;
+ EXPECT_THAT(RetryEINTR(recvfrom)(fd, &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr));
+ auto prev_port = portToFD.find(port);
+ // Check that all packets from one client have been delivered to the same
+ // server socket.
+ if (prev_port == portToFD.end()) {
+ portToFD[port] = ev.data.fd;
+ } else {
+ EXPECT_EQ(portToFD[port], ev.data.fd);
+ }
+ }
+}
+
INSTANTIATE_TEST_SUITE_P(
All, SocketInetReusePortTest,
::testing::Values(