summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrei Vagin <avagin@google.com>2018-12-28 11:26:01 -0800
committerShentubot <shentubot@google.com>2018-12-28 11:27:14 -0800
commit652d068119052b0b3bc4a0808a4400a22380a30b (patch)
treef5a617063151ffb9563ebbcd3189611e854952db
parenta3217b71723a93abb7a2aca535408ab84d81ac2f (diff)
Implement SO_REUSEPORT for TCP and UDP sockets
This option allows multiple sockets to be bound to the same port. Incoming packets are distributed to sockets using a hash based on source and destination addresses. This means that all packets from one sender will be received by the same server socket. PiperOrigin-RevId: 227153413 Change-Id: I59b6edda9c2209d5b8968671e9129adb675920cf
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go20
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go5
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD21
-rw-r--r--pkg/tcpip/hash/jenkins/jenkins.go80
-rw-r--r--pkg/tcpip/hash/jenkins/jenkins_test.go176
-rw-r--r--pkg/tcpip/ports/BUILD4
-rw-r--r--pkg/tcpip/ports/ports.go74
-rw-r--r--pkg/tcpip/ports/ports_test.go134
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/stack.go12
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go144
-rw-r--r--pkg/tcpip/stack/transport_test.go2
-rw-r--r--pkg/tcpip/tcpip.go4
-rw-r--r--pkg/tcpip/transport/ping/endpoint.go8
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go34
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go30
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go85
-rw-r--r--test/syscalls/linux/BUILD4
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc289
-rw-r--r--test/syscalls/syscall_test_runner.go1
21 files changed, 1025 insertions, 105 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 1b9c75949..d65b5f49e 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -634,6 +634,18 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return int32(v), nil
+ case linux.SO_REUSEPORT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.ReusePortOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -900,6 +912,14 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v)))
+ case linux.SO_REUSEPORT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+
case linux.SO_PASSCRED:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 257bc2d71..8c8ebadb7 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -285,7 +285,10 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
if blocking && se == syserr.ErrTryAgain {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventIn)
+ // FIXME: This waiter.EventHUp is a partial
+ // measure, need to figure out how to translate linux events to
+ // internal events.
+ s.EventRegister(&e, waiter.EventIn|waiter.EventHUp)
defer s.EventUnregister(&e)
// Try to accept the connection again; if it fails, then wait until we
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
new file mode 100644
index 000000000..bbb764db8
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -0,0 +1,21 @@
+load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"]) # Apache 2.0
+
+go_library(
+ name = "jenkins",
+ srcs = ["jenkins.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/hash/jenkins",
+ visibility = [
+ "//visibility:public",
+ ],
+)
+
+go_test(
+ name = "jenkins_test",
+ size = "small",
+ srcs = [
+ "jenkins_test.go",
+ ],
+ embed = [":jenkins"],
+)
diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go
new file mode 100644
index 000000000..e66d5f12b
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/jenkins.go
@@ -0,0 +1,80 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package jenkins implements Jenkins's one_at_a_time, non-cryptographic hash
+// functions created by by Bob Jenkins.
+//
+// See https://en.wikipedia.org/wiki/Jenkins_hash_function#cite_note-dobbsx-1
+//
+package jenkins
+
+import (
+ "hash"
+)
+
+// Sum32 represents Jenkins's one_at_a_time hash.
+//
+// Use the Sum32 type directly (as opposed to New32 below)
+// to avoid allocations.
+type Sum32 uint32
+
+// New32 returns a new 32-bit Jenkins's one_at_a_time hash.Hash.
+//
+// Its Sum method will lay the value out in big-endian byte order.
+func New32() hash.Hash32 {
+ var s Sum32
+ return &s
+}
+
+// Reset resets the hash to its initial state.
+func (s *Sum32) Reset() { *s = 0 }
+
+// Sum32 returns the hash value
+func (s *Sum32) Sum32() uint32 {
+ hash := *s
+
+ hash += (hash << 3)
+ hash ^= hash >> 11
+ hash += hash << 15
+
+ return uint32(hash)
+}
+
+// Write adds more data to the running hash.
+//
+// It never returns an error.
+func (s *Sum32) Write(data []byte) (int, error) {
+ hash := *s
+ for _, b := range data {
+ hash += Sum32(b)
+ hash += hash << 10
+ hash ^= hash >> 6
+ }
+ *s = hash
+ return len(data), nil
+}
+
+// Size returns the number of bytes Sum will return.
+func (s *Sum32) Size() int { return 4 }
+
+// BlockSize returns the hash's underlying block size.
+func (s *Sum32) BlockSize() int { return 1 }
+
+// Sum appends the current hash to in and returns the resulting slice.
+//
+// It does not change the underlying hash state.
+func (s *Sum32) Sum(in []byte) []byte {
+ v := s.Sum32()
+ return append(in, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+}
diff --git a/pkg/tcpip/hash/jenkins/jenkins_test.go b/pkg/tcpip/hash/jenkins/jenkins_test.go
new file mode 100644
index 000000000..9d86174aa
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/jenkins_test.go
@@ -0,0 +1,176 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package jenkins
+
+import (
+ "bytes"
+ "encoding/binary"
+ "hash"
+ "hash/fnv"
+ "math"
+ "testing"
+)
+
+func TestGolden32(t *testing.T) {
+ var golden32 = []struct {
+ out []byte
+ in string
+ }{
+ {[]byte{0x00, 0x00, 0x00, 0x00}, ""},
+ {[]byte{0xca, 0x2e, 0x94, 0x42}, "a"},
+ {[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"},
+ {[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"},
+ }
+
+ hash := New32()
+
+ for _, g := range golden32 {
+ hash.Reset()
+ done, error := hash.Write([]byte(g.in))
+ if error != nil {
+ t.Fatalf("write error: %s", error)
+ }
+ if done != len(g.in) {
+ t.Fatalf("wrote only %d out of %d bytes", done, len(g.in))
+ }
+ if actual := hash.Sum(nil); !bytes.Equal(g.out, actual) {
+ t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out)
+ }
+ }
+}
+
+func TestIntegrity32(t *testing.T) {
+ data := []byte{'1', '2', 3, 4, 5}
+
+ h := New32()
+ h.Write(data)
+ sum := h.Sum(nil)
+
+ if size := h.Size(); size != len(sum) {
+ t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum))
+ }
+
+ if a := h.Sum(nil); !bytes.Equal(sum, a) {
+ t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a)
+ }
+
+ h.Reset()
+ h.Write(data)
+ if a := h.Sum(nil); !bytes.Equal(sum, a) {
+ t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a)
+ }
+
+ h.Reset()
+ h.Write(data[:2])
+ h.Write(data[2:])
+ if a := h.Sum(nil); !bytes.Equal(sum, a) {
+ t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a)
+ }
+
+ sum32 := h.(hash.Hash32).Sum32()
+ if sum32 != binary.BigEndian.Uint32(sum) {
+ t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32)
+ }
+}
+
+func BenchmarkJenkins32KB(b *testing.B) {
+ h := New32()
+
+ b.SetBytes(1024)
+ data := make([]byte, 1024)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ in := make([]byte, 0, h.Size())
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ h.Reset()
+ h.Write(data)
+ h.Sum(in)
+ }
+}
+
+func BenchmarkFnv32(b *testing.B) {
+ arr := make([]int64, 1000)
+ for i := 0; i < b.N; i++ {
+ var payload [8]byte
+ binary.BigEndian.PutUint32(payload[:4], uint32(i))
+ binary.BigEndian.PutUint32(payload[4:], uint32(i))
+
+ h := fnv.New32()
+ h.Write(payload[:])
+ idx := int(h.Sum32()) % len(arr)
+ arr[idx]++
+ }
+ b.StopTimer()
+ c := 0
+ if b.N > 1000000 {
+ for i := 0; i < len(arr)-1; i++ {
+ if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
+ if c == 0 {
+ b.Logf("i %d val[i] %d val[i+1] %d b.N %b\n", i, arr[i], arr[i+1], b.N)
+ }
+ c++
+ }
+ }
+ if c > 0 {
+ b.Logf("Unbalanced buckets: %d", c)
+ }
+ }
+}
+
+func BenchmarkSum32(b *testing.B) {
+ arr := make([]int64, 1000)
+ for i := 0; i < b.N; i++ {
+ var payload [8]byte
+ binary.BigEndian.PutUint32(payload[:4], uint32(i))
+ binary.BigEndian.PutUint32(payload[4:], uint32(i))
+ h := Sum32(0)
+ h.Write(payload[:])
+ idx := int(h.Sum32()) % len(arr)
+ arr[idx]++
+ }
+ b.StopTimer()
+ if b.N > 1000000 {
+ for i := 0; i < len(arr)-1; i++ {
+ if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
+ b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N)
+ break
+ }
+ }
+ }
+}
+
+func BenchmarkNew32(b *testing.B) {
+ arr := make([]int64, 1000)
+ for i := 0; i < b.N; i++ {
+ var payload [8]byte
+ binary.BigEndian.PutUint32(payload[:4], uint32(i))
+ binary.BigEndian.PutUint32(payload[4:], uint32(i))
+ h := New32()
+ h.Write(payload[:])
+ idx := int(h.Sum32()) % len(arr)
+ arr[idx]++
+ }
+ b.StopTimer()
+ if b.N > 1000000 {
+ for i := 0; i < len(arr)-1; i++ {
+ if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
+ b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N)
+ break
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index c69fc0744..a2fa9b84a 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -7,7 +7,9 @@ go_library(
srcs = ["ports.go"],
importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/ports",
visibility = ["//:sandbox"],
- deps = ["//pkg/tcpip"],
+ deps = [
+ "//pkg/tcpip",
+ ],
)
go_test(
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 41ef32921..d212a5792 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -42,23 +42,47 @@ type PortManager struct {
allocatedPorts map[portDescriptor]bindAddresses
}
+type portNode struct {
+ reuse bool
+ refs int
+}
+
// bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]struct{}
+type bindAddresses map[tcpip.Address]portNode
// isAvailable checks whether an IP address is available to bind to.
-func (b bindAddresses) isAvailable(addr tcpip.Address) bool {
+func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool {
if addr == anyIPAddress {
- return len(b) == 0
+ if len(b) == 0 {
+ return true
+ }
+ if !reuse {
+ return false
+ }
+ for _, n := range b {
+ if !n.reuse {
+ return false
+ }
+ }
+ return true
}
// If all addresses for this portDescriptor are already bound, no
// address is available.
- if _, ok := b[anyIPAddress]; ok {
- return false
+ if n, ok := b[anyIPAddress]; ok {
+ if !reuse {
+ return false
+ }
+ if !n.reuse {
+ return false
+ }
}
- if _, ok := b[addr]; ok {
- return false
+ if n, ok := b[addr]; ok {
+ if !reuse {
+ return false
+ }
+ return n.reuse
}
return true
}
@@ -92,17 +116,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port)
+ return s.isPortAvailableLocked(networks, transport, addr, port, reuse)
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr) {
+ if !addrs.isAvailable(addr, reuse) {
return false
}
}
@@ -114,14 +138,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, reuse) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -129,13 +153,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) {
return false
}
@@ -147,7 +171,12 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
m = make(bindAddresses)
s.allocatedPorts[desc] = m
}
- m[addr] = struct{}{}
+ if n, ok := m[addr]; ok {
+ n.refs++
+ m[addr] = n
+ } else {
+ m[addr] = portNode{reuse: reuse, refs: 1}
+ }
}
return true
@@ -162,7 +191,16 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
- delete(m, addr)
+ n, ok := m[addr]
+ if !ok {
+ continue
+ }
+ n.refs--
+ if n.refs == 0 {
+ delete(m, addr)
+ } else {
+ m[addr] = n
+ }
if len(m) == 0 {
delete(s.allocatedPorts, desc)
}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 72577dfcb..01e7320b4 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -28,67 +28,99 @@ const (
fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09")
)
-func TestPortReservation(t *testing.T) {
- pm := NewPortManager()
- net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
+type portReserveTestAction struct {
+ port uint16
+ ip tcpip.Address
+ want *tcpip.Error
+ reuse bool
+ release bool
+}
+func TestPortReservation(t *testing.T) {
for _, test := range []struct {
- port uint16
- ip tcpip.Address
- want *tcpip.Error
+ tname string
+ actions []portReserveTestAction
}{
{
- port: 80,
- ip: fakeIPAddress,
- want: nil,
- },
- {
- port: 80,
- ip: fakeIPAddress1,
- want: nil,
- },
- {
- /* N.B. Order of tests matters! */
- port: 80,
- ip: anyIPAddress,
- want: tcpip.ErrPortInUse,
- },
- {
- port: 22,
- ip: anyIPAddress,
- want: nil,
- },
- {
- port: 22,
- ip: fakeIPAddress,
- want: tcpip.ErrPortInUse,
- },
- {
- port: 0,
- ip: fakeIPAddress,
- want: nil,
+ tname: "bind to ip",
+ actions: []portReserveTestAction{
+ {port: 80, ip: fakeIPAddress, want: nil},
+ {port: 80, ip: fakeIPAddress1, want: nil},
+ /* N.B. Order of tests matters! */
+ {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse},
+ {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true},
+ },
},
{
- port: 0,
- ip: fakeIPAddress,
- want: nil,
+ tname: "bind to inaddr any",
+ actions: []portReserveTestAction{
+ {port: 22, ip: anyIPAddress, want: nil},
+ {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ /* release fakeIPAddress, but anyIPAddress is still inuse */
+ {port: 22, ip: fakeIPAddress, release: true},
+ {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true},
+ /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */
+ {port: 22, ip: anyIPAddress, want: nil, release: true},
+ {port: 22, ip: fakeIPAddress, want: nil},
+ },
+ }, {
+ tname: "bind to zero port",
+ actions: []portReserveTestAction{
+ {port: 00, ip: fakeIPAddress, want: nil},
+ {port: 00, ip: fakeIPAddress, want: nil},
+ {port: 00, ip: fakeIPAddress, reuse: true, want: nil},
+ },
+ }, {
+ tname: "bind to ip with reuseport",
+ actions: []portReserveTestAction{
+ {port: 25, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 25, ip: fakeIPAddress, reuse: true, want: nil},
+
+ {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+
+ {port: 25, ip: anyIPAddress, reuse: true, want: nil},
+ },
+ }, {
+ tname: "bind to inaddr any with reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: anyIPAddress, reuse: true, want: nil},
+ {port: 24, ip: anyIPAddress, reuse: true, want: nil},
+
+ {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+
+ {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, release: true, want: nil},
+
+ {port: 24, ip: anyIPAddress, release: true},
+ {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+
+ {port: 24, ip: anyIPAddress, release: true},
+ {port: 24, ip: anyIPAddress, reuse: false, want: nil},
+ },
},
} {
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port)
- if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d) = %v, want %v", test.ip, test.port, err, test.want)
- }
- if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
- t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
- }
- }
+ t.Run(test.tname, func(t *testing.T) {
+ pm := NewPortManager()
+ net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
- // Release port 22 from any IP address, then try to reserve fake IP
- // address on 22.
- pm.ReleasePort(net, fakeTransNumber, anyIPAddress, 22)
+ for _, test := range test.actions {
+ if test.release {
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port)
+ continue
+ }
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse)
+ if err != test.want {
+ t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want)
+ }
+ if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
+ t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
+ }
+ }
+ })
- if port, err := pm.ReservePort(net, fakeTransNumber, fakeIPAddress, 22); port != 22 || err != nil {
- t.Fatalf("ReservePort(.., .., .., %d) = (port %d, err %v), want (22, nil); failed to reserve port after it should have been released", 22, port, err)
}
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 90cc05cda..9ff1c8731 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -22,6 +22,7 @@ go_library(
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0ac116675..7aa9dbd46 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -883,9 +883,9 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
if nicID == 0 {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep)
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
}
s.mu.RLock()
@@ -896,14 +896,14 @@ func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.N
return tcpip.ErrUnknownNICID
}
- return nic.demux.registerEndpoint(netProtos, protocol, id, ep)
+ return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
if nicID == 0 {
- s.demux.unregisterEndpoint(netProtos, protocol, id)
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
return
}
@@ -912,7 +912,7 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
nic := s.nics[nicID]
if nic != nil {
- nic.demux.unregisterEndpoint(netProtos, protocol, id)
+ nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index c8522ad9e..a5ff2159a 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -15,10 +15,12 @@
package stack
import (
+ "math/rand"
"sync"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
)
@@ -34,6 +36,23 @@ type transportEndpoints struct {
endpoints map[TransportEndpointID]TransportEndpoint
}
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ e, ok := eps.endpoints[id]
+ if !ok {
+ return
+ }
+ if multiPortEp, ok := e.(*multiPortEndpoint); ok {
+ if !multiPortEp.unregisterEndpoint(ep) {
+ return
+ }
+ }
+ delete(eps.endpoints, id)
+}
+
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
@@ -57,10 +76,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
return err
}
}
@@ -68,7 +87,97 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
return nil
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+// multiPortEndpoint is a container for TransportEndpoints which are bound to
+// the same pair of address and port.
+type multiPortEndpoint struct {
+ mu sync.RWMutex
+ endpointsArr []TransportEndpoint
+ endpointsMap map[TransportEndpoint]int
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// reciprocalScale scales a value into range [0, n).
+//
+// This is similar to val % n, but faster.
+// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
+func reciprocalScale(val, n uint32) uint32 {
+ return uint32((uint64(val) * uint64(n)) >> 32)
+}
+
+// selectEndpoint calculates a hash of destination and source addresses and
+// ports then uses it to select a socket. In this case, all packets from one
+// address will be sent to same endpoint.
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ payload := []byte{
+ byte(id.LocalPort),
+ byte(id.LocalPort >> 8),
+ byte(id.RemotePort),
+ byte(id.RemotePort >> 8),
+ }
+
+ h := jenkins.Sum32(ep.seed)
+ h.Write(payload)
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+ hash := h.Sum32()
+
+ idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
+ return ep.endpointsArr[idx]
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ ep.selectEndpoint(id).HandlePacket(r, id, vv)
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
+}
+
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ // A new endpoint is added into endpointsArr and its index there is
+ // saved in endpointsMap. This will allows to remove endpoint from
+ // the array fast.
+ ep.endpointsMap[ep] = len(ep.endpointsArr)
+ ep.endpointsArr = append(ep.endpointsArr, t)
+}
+
+// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ idx, ok := ep.endpointsMap[t]
+ if !ok {
+ return false
+ }
+ delete(ep.endpointsMap, t)
+ l := len(ep.endpointsArr)
+ if l > 1 {
+ // The last endpoint in endpointsArr is moved instead of the deleted one.
+ lastEp := ep.endpointsArr[l-1]
+ ep.endpointsArr[idx] = lastEp
+ ep.endpointsMap[lastEp] = idx
+ ep.endpointsArr = ep.endpointsArr[0 : l-1]
+ return false
+ }
+ return true
+}
+
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ if id.RemotePort != 0 {
+ reusePort = false
+ }
+
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
return nil
@@ -77,10 +186,29 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.mu.Lock()
defer eps.mu.Unlock()
+ var multiPortEp *multiPortEndpoint
if _, ok := eps.endpoints[id]; ok {
- return tcpip.ErrPortInUse
+ if !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
+ if !ok {
+ return tcpip.ErrPortInUse
+ }
}
+ if reusePort {
+ if multiPortEp == nil {
+ multiPortEp = &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.seed = rand.Uint32()
+ eps.endpoints[id] = multiPortEp
+ }
+
+ multiPortEp.singleRegisterEndpoint(ep)
+
+ return nil
+ }
eps.endpoints[id] = ep
return nil
@@ -88,12 +216,10 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.mu.Lock()
- delete(eps.endpoints, id)
- eps.mu.Unlock()
+ eps.unregisterEndpoint(id, ep)
}
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index f09760180..022207081 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -107,7 +107,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.id.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
if err != nil {
return err
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 627786808..7d4fbe075 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -436,6 +436,10 @@ type CorkOption int
// should allow reuse of local address.
type ReuseAddressOption int
+// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
+// to be bound to an identical socket address.
+type ReusePortOption int
+
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go
index d1b9b136c..29f6c543d 100644
--- a/pkg/tcpip/transport/ping/endpoint.go
+++ b/pkg/tcpip/transport/ping/endpoint.go
@@ -100,7 +100,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
}
// Close the receive list and drain it.
@@ -541,14 +541,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
switch err {
case nil:
return true, nil
@@ -597,7 +597,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error
if commit != nil {
if err := commit(); err != nil {
// Unregister, the commit failed.
- e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id)
+ e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id, e)
return err
}
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d0e1d6782..78d2c76e0 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -215,7 +215,7 @@ func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, ir
n.maybeEnableSACKPermitted(rcvdSynOpts)
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
n.Close()
return nil, err
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d4eda50ec..5281f8be2 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -162,6 +162,9 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
+ // reusePort is set to true if SO_REUSEPORT is enabled.
+ reusePort bool
+
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -416,7 +419,7 @@ func (e *endpoint) Close() {
e.isPortReserved = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
e.isRegistered = false
}
}
@@ -453,7 +456,7 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
}
e.route.Release()
@@ -681,6 +684,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
case tcpip.QuickAckOption:
if v == 0 {
atomic.StoreUint32(&e.slowAck, 1)
@@ -875,6 +884,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
case *tcpip.QuickAckOption:
*o = 1
if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1057,7 +1077,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
if e.id.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
if err != nil {
return err
}
@@ -1071,13 +1091,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
if sameAddr && p == e.id.RemotePort {
return false, nil
}
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) {
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
return false, nil
}
id := e.id
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) {
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
case nil:
e.id = id
return true, nil
@@ -1234,7 +1254,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
return err
}
@@ -1315,7 +1335,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 67e9ca0ac..b2a27a7cb 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -81,6 +81,7 @@ type endpoint struct {
dstPort uint16
v6only bool
multicastTTL uint8
+ reusePort bool
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -132,7 +133,7 @@ func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.Transport
ep := newEndpoint(stack, r.NetProto, waiterQueue)
// Register new endpoint so that packets are routed to it.
- if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil {
+ if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep, ep.reusePort); err != nil {
ep.Close()
return nil, err
}
@@ -155,7 +156,7 @@ func (e *endpoint) Close() {
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
}
@@ -449,6 +450,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
break
}
}
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
}
return nil
}
@@ -513,6 +520,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case *tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -648,7 +666,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.id.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
}
e.id = id
@@ -711,14 +729,14 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if e.id.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
if err != nil {
return id, err
}
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
if err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
}
@@ -766,7 +784,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error
if commit != nil {
if err := commit(); err != nil {
// Unregister, the commit failed.
- e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id)
+ e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id, e)
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
return err
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 58a346cd9..2a9cf4b57 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "math"
"math/rand"
"testing"
"time"
@@ -254,6 +255,90 @@ func newPayload() []byte {
return b
}
+func TestBindPortReuse(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ var eps [5]tcpip.Endpoint
+ reusePortOpt := tcpip.ReusePortOption(1)
+
+ pollChannel := make(chan tcpip.Endpoint)
+ for i := 0; i < len(eps); i++ {
+ // Try to receive the data.
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ var err *tcpip.Error
+ eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ go func(ep tcpip.Endpoint) {
+ for range ch {
+ pollChannel <- ep
+ }
+ }(eps[i])
+
+ defer eps[i].Close()
+ if err := eps[i].SetSockOpt(reusePortOpt); err != nil {
+ c.t.Fatalf("SetSockOpt failed failed: %v", err)
+ }
+ if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, nil); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+ }
+
+ npackets := 100000
+ nports := 10000
+ ports := make(map[uint16]tcpip.Endpoint)
+ stats := make(map[tcpip.Endpoint]int)
+ for i := 0; i < npackets; i++ {
+ // Send a packet.
+ port := uint16(i % nports)
+ payload := newPayload()
+ c.sendV6Packet(payload, &headers{
+ srcPort: testPort + port,
+ dstPort: stackPort,
+ })
+
+ var addr tcpip.FullAddress
+ ep := <-pollChannel
+ _, _, err := ep.Read(&addr)
+ if err != nil {
+ c.t.Fatalf("Read failed: %v", err)
+ }
+ stats[ep]++
+ if i < nports {
+ ports[uint16(i)] = ep
+ } else {
+ // Check that all packets from one client are handled
+ // by the same socket.
+ if ports[port] != ep {
+ t.Fatalf("Port mismatch")
+ }
+ }
+ }
+
+ if len(stats) != len(eps) {
+ t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps))
+ }
+
+ // Check that a packet distribution is fair between sockets.
+ for _, c := range stats {
+ n := float64(npackets) / float64(len(eps))
+ // The deviation is less than 10%.
+ if math.Abs(float64(c)-n) > n/10 {
+ t.Fatal(c, n)
+ }
+ }
+}
+
func testV4Read(c *testContext) {
// Send a packet.
payload := newPayload()
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index ae33d14da..f0e61e083 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -2163,9 +2163,13 @@ cc_binary(
":socket_test_util",
"//test/util:file_descriptor",
"//test/util:posix_error",
+ "//test/util:save_util",
"//test/util:test_main",
"//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 17a46e149..0893be5a7 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -17,17 +17,24 @@
#include <string.h>
#include <sys/socket.h>
+#include <atomic>
+#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
+#include "absl/time/time.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
@@ -227,6 +234,238 @@ INSTANTIATE_TEST_CASE_P(
TestParam{V6Loopback(), V6Loopback()}),
DescribeTestParam);
+using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>;
+
+TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+ constexpr int kThreadCount = 3;
+
+ // Create the listening socket.
+ FileDescriptor listener_fds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ listener_fds[i] = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ int fd = listener_fds[i].get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (i != 0) continue;
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> connects_received = ATOMIC_VAR_INIT(0);
+ std::unique_ptr<ScopedThread> listen_thread[kThreadCount];
+ int accept_counts[kThreadCount] = {};
+ // TODO: figure how to not disable S/R for the whole test.
+ // We need to take into account that this test executes a lot of system
+ // calls from many threads.
+ DisableSave ds;
+
+ for (int i = 0; i < kThreadCount; i++) {
+ listen_thread[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &accept_counts, i, &connects_received]() {
+ do {
+ auto fd = Accept(listener_fds[i].get(), nullptr, nullptr);
+ if (!fd.ok()) {
+ if (connects_received >= kConnectAttempts) {
+ // Another thread have shutdown our read side causing the
+ // accept to fail.
+ break;
+ }
+ ASSERT_NO_ERRNO(fd);
+ break;
+ }
+ // Receive some data from a socket to be sure that the connect()
+ // system call has been completed on another side.
+ int data;
+ EXPECT_THAT(
+ RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ accept_counts[i]++;
+ } while (++connects_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (int j = 0; j < kThreadCount; j++) {
+ shutdown(listener_fds[j].get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ ScopedThread connecting_thread([&connector, &conn_addr]() {
+ for (int i = 0; i < kConnectAttempts; i++) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+ });
+
+ // Join threads to be sure that all connections have been counted
+ connecting_thread.Join();
+ for (int i = 0; i < kThreadCount; i++) {
+ listen_thread[i]->Join();
+ }
+ // Check that connections are distributed fairly between listening sockets
+ for (int i = 0; i < kThreadCount; i++)
+ EXPECT_THAT(accept_counts[i],
+ EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
+}
+
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+ constexpr int kThreadCount = 3;
+
+ // Create the listening socket.
+ FileDescriptor listener_fds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ listener_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0));
+ int fd = listener_fds[i].get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (i != 0) continue;
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> packets_received = ATOMIC_VAR_INIT(0);
+ std::unique_ptr<ScopedThread> receiver_thread[kThreadCount];
+ int packets_per_socket[kThreadCount] = {};
+ // TODO: figure how to not disable S/R for the whole test.
+ DisableSave ds; // Too expensive.
+
+ for (int i = 0; i < kThreadCount; i++) {
+ receiver_thread[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &packets_per_socket, i, &packets_received]() {
+ do {
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ int data;
+
+ auto ret = RetryEINTR(recvfrom)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
+
+ if (packets_received < kConnectAttempts) {
+ ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ if (ret != sizeof(data)) {
+ // Another thread may have shutdown our read side causing the
+ // recvfrom to fail.
+ break;
+ }
+
+ packets_received++;
+ packets_per_socket[i]++;
+
+ // A response is required to synchronize with the main thread,
+ // otherwise the main thread can send more than can fit into receive
+ // queues.
+ EXPECT_THAT(RetryEINTR(sendto)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ } while (packets_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (int j = 0; j < kThreadCount; j++)
+ shutdown(listener_fds[j].get(), SHUT_RDWR);
+ });
+ }
+
+ ScopedThread main_thread([&connector, &conn_addr]() {
+ for (int i = 0; i < kConnectAttempts; i++) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ int data;
+ EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ }
+ });
+
+ main_thread.Join();
+
+ // Join threads to be sure that all connections have been counted
+ for (int i = 0; i < kThreadCount; i++) {
+ receiver_thread[i]->Join();
+ }
+ // Check that packets are distributed fairly between listening sockets.
+ for (int i = 0; i < kThreadCount; i++)
+ EXPECT_THAT(packets_per_socket[i],
+ EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ All, SocketInetReusePortTest,
+ ::testing::Values(
+ // Listeners bound to IPv4 addresses refuse connections using IPv6
+ // addresses.
+ TestParam{V4Any(), V4Loopback()},
+ TestParam{V4Loopback(), V4MappedLoopback()},
+
+ // Listeners bound to IN6ADDR_ANY accept all connections.
+ TestParam{V6Any(), V4Loopback()}, TestParam{V6Any(), V6Loopback()},
+
+ // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4
+ // addresses.
+ TestParam{V6Loopback(), V6Loopback()}),
+ DescribeTestParam);
+
struct ProtocolTestParam {
std::string description;
int type;
@@ -806,6 +1045,56 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
}
}
+TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) {
+ auto const& param = GetParam();
+ TestAddress const& test_addr = V4Loopback();
+ sockaddr_storage addr = test_addr.addr;
+
+ for (int i = 0; i < 2; i++) {
+ const int portreuse1 = i % 2;
+ auto s1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ int fd1 = s1.get();
+ socklen_t addrlen = test_addr.addr_len;
+
+ EXPECT_THAT(
+ setsockopt(fd1, SOL_SOCKET, SO_REUSEPORT, &portreuse1, sizeof(int)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(fd1, reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(getsockname(fd1, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(fd1, 1), SyscallSucceeds());
+ }
+
+ // j is less than 4 to check that the port reuse logic works correctly after
+ // closing bound sockets.
+ for (int j = 0; j < 4; j++) {
+ const int portreuse2 = j % 2;
+ auto s2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ int fd2 = s2.get();
+
+ EXPECT_THAT(
+ setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)),
+ SyscallSucceeds());
+
+ LOG(INFO) << portreuse1 << " " << portreuse2;
+ int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen);
+
+ // Verify that two sockets can be bound to the same port only if
+ // SO_REUSEPORT is set for both of them.
+ if (!portreuse1 || !portreuse2)
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRINUSE));
+ else
+ ASSERT_THAT(ret, SyscallSucceeds());
+ }
+ }
+}
+
INSTANTIATE_TEST_CASE_P(AllFamlies, SocketMultiProtocolInetLoopbackTest,
::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM},
ProtocolTestParam{"UDP", SOCK_DGRAM}),
diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go
index 9ee0361ee..ec048f10f 100644
--- a/test/syscalls/syscall_test_runner.go
+++ b/test/syscalls/syscall_test_runner.go
@@ -118,6 +118,7 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Mark the root as writeable, as some tests attempt to
// write to the rootfs, and expect EACCES, not EROFS.
spec.Root.Readonly = false
+ spec.Mounts = nil
// Set environment variable that indicates we are
// running in gVisor and with the given platform.