diff options
author | Andrei Vagin <avagin@google.com> | 2018-12-28 11:26:01 -0800 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-12-28 11:27:14 -0800 |
commit | 652d068119052b0b3bc4a0808a4400a22380a30b (patch) | |
tree | f5a617063151ffb9563ebbcd3189611e854952db | |
parent | a3217b71723a93abb7a2aca535408ab84d81ac2f (diff) |
Implement SO_REUSEPORT for TCP and UDP sockets
This option allows multiple sockets to be bound to the same port.
Incoming packets are distributed to sockets using a hash based on source and
destination addresses. This means that all packets from one sender will be
received by the same server socket.
PiperOrigin-RevId: 227153413
Change-Id: I59b6edda9c2209d5b8968671e9129adb675920cf
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 20 | ||||
-rw-r--r-- | pkg/sentry/socket/rpcinet/socket.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/hash/jenkins/BUILD | 21 | ||||
-rw-r--r-- | pkg/tcpip/hash/jenkins/jenkins.go | 80 | ||||
-rw-r--r-- | pkg/tcpip/hash/jenkins/jenkins_test.go | 176 | ||||
-rw-r--r-- | pkg/tcpip/ports/BUILD | 4 | ||||
-rw-r--r-- | pkg/tcpip/ports/ports.go | 74 | ||||
-rw-r--r-- | pkg/tcpip/ports/ports_test.go | 134 | ||||
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 144 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/ping/endpoint.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 85 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 4 | ||||
-rw-r--r-- | test/syscalls/linux/socket_inet_loopback.cc | 289 | ||||
-rw-r--r-- | test/syscalls/syscall_test_runner.go | 1 |
21 files changed, 1025 insertions, 105 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 1b9c75949..d65b5f49e 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -634,6 +634,18 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return int32(v), nil + case linux.SO_REUSEPORT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.ReusePortOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(v), nil + case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -900,6 +912,14 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v))) + case linux.SO_REUSEPORT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + case linux.SO_PASSCRED: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 257bc2d71..8c8ebadb7 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -285,7 +285,10 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if blocking && se == syserr.ErrTryAgain { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) - s.EventRegister(&e, waiter.EventIn) + // FIXME: This waiter.EventHUp is a partial + // measure, need to figure out how to translate linux events to + // internal events. + s.EventRegister(&e, waiter.EventIn|waiter.EventHUp) defer s.EventUnregister(&e) // Try to accept the connection again; if it fails, then wait until we diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD new file mode 100644 index 000000000..bbb764db8 --- /dev/null +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -0,0 +1,21 @@ +load("//tools/go_stateify:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) # Apache 2.0 + +go_library( + name = "jenkins", + srcs = ["jenkins.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/hash/jenkins", + visibility = [ + "//visibility:public", + ], +) + +go_test( + name = "jenkins_test", + size = "small", + srcs = [ + "jenkins_test.go", + ], + embed = [":jenkins"], +) diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go new file mode 100644 index 000000000..e66d5f12b --- /dev/null +++ b/pkg/tcpip/hash/jenkins/jenkins.go @@ -0,0 +1,80 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package jenkins implements Jenkins's one_at_a_time, non-cryptographic hash +// functions created by by Bob Jenkins. +// +// See https://en.wikipedia.org/wiki/Jenkins_hash_function#cite_note-dobbsx-1 +// +package jenkins + +import ( + "hash" +) + +// Sum32 represents Jenkins's one_at_a_time hash. +// +// Use the Sum32 type directly (as opposed to New32 below) +// to avoid allocations. +type Sum32 uint32 + +// New32 returns a new 32-bit Jenkins's one_at_a_time hash.Hash. +// +// Its Sum method will lay the value out in big-endian byte order. +func New32() hash.Hash32 { + var s Sum32 + return &s +} + +// Reset resets the hash to its initial state. +func (s *Sum32) Reset() { *s = 0 } + +// Sum32 returns the hash value +func (s *Sum32) Sum32() uint32 { + hash := *s + + hash += (hash << 3) + hash ^= hash >> 11 + hash += hash << 15 + + return uint32(hash) +} + +// Write adds more data to the running hash. +// +// It never returns an error. +func (s *Sum32) Write(data []byte) (int, error) { + hash := *s + for _, b := range data { + hash += Sum32(b) + hash += hash << 10 + hash ^= hash >> 6 + } + *s = hash + return len(data), nil +} + +// Size returns the number of bytes Sum will return. +func (s *Sum32) Size() int { return 4 } + +// BlockSize returns the hash's underlying block size. +func (s *Sum32) BlockSize() int { return 1 } + +// Sum appends the current hash to in and returns the resulting slice. +// +// It does not change the underlying hash state. +func (s *Sum32) Sum(in []byte) []byte { + v := s.Sum32() + return append(in, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} diff --git a/pkg/tcpip/hash/jenkins/jenkins_test.go b/pkg/tcpip/hash/jenkins/jenkins_test.go new file mode 100644 index 000000000..9d86174aa --- /dev/null +++ b/pkg/tcpip/hash/jenkins/jenkins_test.go @@ -0,0 +1,176 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package jenkins + +import ( + "bytes" + "encoding/binary" + "hash" + "hash/fnv" + "math" + "testing" +) + +func TestGolden32(t *testing.T) { + var golden32 = []struct { + out []byte + in string + }{ + {[]byte{0x00, 0x00, 0x00, 0x00}, ""}, + {[]byte{0xca, 0x2e, 0x94, 0x42}, "a"}, + {[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"}, + {[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"}, + } + + hash := New32() + + for _, g := range golden32 { + hash.Reset() + done, error := hash.Write([]byte(g.in)) + if error != nil { + t.Fatalf("write error: %s", error) + } + if done != len(g.in) { + t.Fatalf("wrote only %d out of %d bytes", done, len(g.in)) + } + if actual := hash.Sum(nil); !bytes.Equal(g.out, actual) { + t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out) + } + } +} + +func TestIntegrity32(t *testing.T) { + data := []byte{'1', '2', 3, 4, 5} + + h := New32() + h.Write(data) + sum := h.Sum(nil) + + if size := h.Size(); size != len(sum) { + t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum)) + } + + if a := h.Sum(nil); !bytes.Equal(sum, a) { + t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a) + } + + h.Reset() + h.Write(data) + if a := h.Sum(nil); !bytes.Equal(sum, a) { + t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a) + } + + h.Reset() + h.Write(data[:2]) + h.Write(data[2:]) + if a := h.Sum(nil); !bytes.Equal(sum, a) { + t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a) + } + + sum32 := h.(hash.Hash32).Sum32() + if sum32 != binary.BigEndian.Uint32(sum) { + t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32) + } +} + +func BenchmarkJenkins32KB(b *testing.B) { + h := New32() + + b.SetBytes(1024) + data := make([]byte, 1024) + for i := range data { + data[i] = byte(i) + } + in := make([]byte, 0, h.Size()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + h.Reset() + h.Write(data) + h.Sum(in) + } +} + +func BenchmarkFnv32(b *testing.B) { + arr := make([]int64, 1000) + for i := 0; i < b.N; i++ { + var payload [8]byte + binary.BigEndian.PutUint32(payload[:4], uint32(i)) + binary.BigEndian.PutUint32(payload[4:], uint32(i)) + + h := fnv.New32() + h.Write(payload[:]) + idx := int(h.Sum32()) % len(arr) + arr[idx]++ + } + b.StopTimer() + c := 0 + if b.N > 1000000 { + for i := 0; i < len(arr)-1; i++ { + if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { + if c == 0 { + b.Logf("i %d val[i] %d val[i+1] %d b.N %b\n", i, arr[i], arr[i+1], b.N) + } + c++ + } + } + if c > 0 { + b.Logf("Unbalanced buckets: %d", c) + } + } +} + +func BenchmarkSum32(b *testing.B) { + arr := make([]int64, 1000) + for i := 0; i < b.N; i++ { + var payload [8]byte + binary.BigEndian.PutUint32(payload[:4], uint32(i)) + binary.BigEndian.PutUint32(payload[4:], uint32(i)) + h := Sum32(0) + h.Write(payload[:]) + idx := int(h.Sum32()) % len(arr) + arr[idx]++ + } + b.StopTimer() + if b.N > 1000000 { + for i := 0; i < len(arr)-1; i++ { + if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { + b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) + break + } + } + } +} + +func BenchmarkNew32(b *testing.B) { + arr := make([]int64, 1000) + for i := 0; i < b.N; i++ { + var payload [8]byte + binary.BigEndian.PutUint32(payload[:4], uint32(i)) + binary.BigEndian.PutUint32(payload[4:], uint32(i)) + h := New32() + h.Write(payload[:]) + idx := int(h.Sum32()) % len(arr) + arr[idx]++ + } + b.StopTimer() + if b.N > 1000000 { + for i := 0; i < len(arr)-1; i++ { + if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { + b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) + break + } + } + } +} diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index c69fc0744..a2fa9b84a 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -7,7 +7,9 @@ go_library( srcs = ["ports.go"], importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/ports", visibility = ["//:sandbox"], - deps = ["//pkg/tcpip"], + deps = [ + "//pkg/tcpip", + ], ) go_test( diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 41ef32921..d212a5792 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -42,23 +42,47 @@ type PortManager struct { allocatedPorts map[portDescriptor]bindAddresses } +type portNode struct { + reuse bool + refs int +} + // bindAddresses is a set of IP addresses. -type bindAddresses map[tcpip.Address]struct{} +type bindAddresses map[tcpip.Address]portNode // isAvailable checks whether an IP address is available to bind to. -func (b bindAddresses) isAvailable(addr tcpip.Address) bool { +func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool { if addr == anyIPAddress { - return len(b) == 0 + if len(b) == 0 { + return true + } + if !reuse { + return false + } + for _, n := range b { + if !n.reuse { + return false + } + } + return true } // If all addresses for this portDescriptor are already bound, no // address is available. - if _, ok := b[anyIPAddress]; ok { - return false + if n, ok := b[anyIPAddress]; ok { + if !reuse { + return false + } + if !n.reuse { + return false + } } - if _, ok := b[addr]; ok { - return false + if n, ok := b[addr]; ok { + if !reuse { + return false + } + return n.reuse } return true } @@ -92,17 +116,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port) + return s.isPortAvailableLocked(networks, transport, addr, port, reuse) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr) { + if !addrs.isAvailable(addr, reuse) { return false } } @@ -114,14 +138,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port) { + if !s.reserveSpecificPort(networks, transport, addr, port, reuse) { return 0, tcpip.ErrPortInUse } return port, nil @@ -129,13 +153,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p), nil + return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) { return false } @@ -147,7 +171,12 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber m = make(bindAddresses) s.allocatedPorts[desc] = m } - m[addr] = struct{}{} + if n, ok := m[addr]; ok { + n.refs++ + m[addr] = n + } else { + m[addr] = portNode{reuse: reuse, refs: 1} + } } return true @@ -162,7 +191,16 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { - delete(m, addr) + n, ok := m[addr] + if !ok { + continue + } + n.refs-- + if n.refs == 0 { + delete(m, addr) + } else { + m[addr] = n + } if len(m) == 0 { delete(s.allocatedPorts, desc) } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 72577dfcb..01e7320b4 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -28,67 +28,99 @@ const ( fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09") ) -func TestPortReservation(t *testing.T) { - pm := NewPortManager() - net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} +type portReserveTestAction struct { + port uint16 + ip tcpip.Address + want *tcpip.Error + reuse bool + release bool +} +func TestPortReservation(t *testing.T) { for _, test := range []struct { - port uint16 - ip tcpip.Address - want *tcpip.Error + tname string + actions []portReserveTestAction }{ { - port: 80, - ip: fakeIPAddress, - want: nil, - }, - { - port: 80, - ip: fakeIPAddress1, - want: nil, - }, - { - /* N.B. Order of tests matters! */ - port: 80, - ip: anyIPAddress, - want: tcpip.ErrPortInUse, - }, - { - port: 22, - ip: anyIPAddress, - want: nil, - }, - { - port: 22, - ip: fakeIPAddress, - want: tcpip.ErrPortInUse, - }, - { - port: 0, - ip: fakeIPAddress, - want: nil, + tname: "bind to ip", + actions: []portReserveTestAction{ + {port: 80, ip: fakeIPAddress, want: nil}, + {port: 80, ip: fakeIPAddress1, want: nil}, + /* N.B. Order of tests matters! */ + {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, + {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + }, }, { - port: 0, - ip: fakeIPAddress, - want: nil, + tname: "bind to inaddr any", + actions: []portReserveTestAction{ + {port: 22, ip: anyIPAddress, want: nil}, + {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + /* release fakeIPAddress, but anyIPAddress is still inuse */ + {port: 22, ip: fakeIPAddress, release: true}, + {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ + {port: 22, ip: anyIPAddress, want: nil, release: true}, + {port: 22, ip: fakeIPAddress, want: nil}, + }, + }, { + tname: "bind to zero port", + actions: []portReserveTestAction{ + {port: 00, ip: fakeIPAddress, want: nil}, + {port: 00, ip: fakeIPAddress, want: nil}, + {port: 00, ip: fakeIPAddress, reuse: true, want: nil}, + }, + }, { + tname: "bind to ip with reuseport", + actions: []portReserveTestAction{ + {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, + + {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + + {port: 25, ip: anyIPAddress, reuse: true, want: nil}, + }, + }, { + tname: "bind to inaddr any with reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: anyIPAddress, reuse: true, want: nil}, + {port: 24, ip: anyIPAddress, reuse: true, want: nil}, + + {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + + {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, release: true, want: nil}, + + {port: 24, ip: anyIPAddress, release: true}, + {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + + {port: 24, ip: anyIPAddress, release: true}, + {port: 24, ip: anyIPAddress, reuse: false, want: nil}, + }, }, } { - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port) - if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d) = %v, want %v", test.ip, test.port, err, test.want) - } - if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) - } - } + t.Run(test.tname, func(t *testing.T) { + pm := NewPortManager() + net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} - // Release port 22 from any IP address, then try to reserve fake IP - // address on 22. - pm.ReleasePort(net, fakeTransNumber, anyIPAddress, 22) + for _, test := range test.actions { + if test.release { + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port) + continue + } + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse) + if err != test.want { + t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want) + } + if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { + t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) + } + } + }) - if port, err := pm.ReservePort(net, fakeTransNumber, fakeIPAddress, 22); port != 22 || err != nil { - t.Fatalf("ReservePort(.., .., .., %d) = (port %d, err %v), want (22, nil); failed to reserve port after it should have been released", 22, port, err) } } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 90cc05cda..9ff1c8731 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -22,6 +22,7 @@ go_library( "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0ac116675..7aa9dbd46 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -883,9 +883,9 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { if nicID == 0 { - return s.demux.registerEndpoint(netProtos, protocol, id, ep) + return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) } s.mu.RLock() @@ -896,14 +896,14 @@ func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.N return tcpip.ErrUnknownNICID } - return nic.demux.registerEndpoint(netProtos, protocol, id, ep) + return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { if nicID == 0 { - s.demux.unregisterEndpoint(netProtos, protocol, id) + s.demux.unregisterEndpoint(netProtos, protocol, id, ep) return } @@ -912,7 +912,7 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip nic := s.nics[nicID] if nic != nil { - nic.demux.unregisterEndpoint(netProtos, protocol, id) + nic.demux.unregisterEndpoint(netProtos, protocol, id, ep) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index c8522ad9e..a5ff2159a 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,10 +15,12 @@ package stack import ( + "math/rand" "sync" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/hash/jenkins" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" ) @@ -34,6 +36,23 @@ type transportEndpoints struct { endpoints map[TransportEndpointID]TransportEndpoint } +// unregisterEndpoint unregisters the endpoint with the given id such that it +// won't receive any more packets. +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) { + eps.mu.Lock() + defer eps.mu.Unlock() + e, ok := eps.endpoints[id] + if !ok { + return + } + if multiPortEp, ok := e.(*multiPortEndpoint); ok { + if !multiPortEp.unregisterEndpoint(ep) { + return + } + } + delete(eps.endpoints, id) +} + // transportDemuxer demultiplexes packets targeted at a transport endpoint // (i.e., after they've been parsed by the network layer). It does two levels // of demultiplexing: first based on the network and transport protocols, then @@ -57,10 +76,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep) return err } } @@ -68,7 +87,97 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { +// multiPortEndpoint is a container for TransportEndpoints which are bound to +// the same pair of address and port. +type multiPortEndpoint struct { + mu sync.RWMutex + endpointsArr []TransportEndpoint + endpointsMap map[TransportEndpoint]int + // seed is a random secret for a jenkins hash. + seed uint32 +} + +// reciprocalScale scales a value into range [0, n). +// +// This is similar to val % n, but faster. +// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ +func reciprocalScale(val, n uint32) uint32 { + return uint32((uint64(val) * uint64(n)) >> 32) +} + +// selectEndpoint calculates a hash of destination and source addresses and +// ports then uses it to select a socket. In this case, all packets from one +// address will be sent to same endpoint. +func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint { + ep.mu.RLock() + defer ep.mu.RUnlock() + + payload := []byte{ + byte(id.LocalPort), + byte(id.LocalPort >> 8), + byte(id.RemotePort), + byte(id.RemotePort >> 8), + } + + h := jenkins.Sum32(ep.seed) + h.Write(payload) + h.Write([]byte(id.LocalAddress)) + h.Write([]byte(id.RemoteAddress)) + hash := h.Sum32() + + idx := reciprocalScale(hash, uint32(len(ep.endpointsArr))) + return ep.endpointsArr[idx] +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + ep.selectEndpoint(id).HandlePacket(r, id, vv) +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { + ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv) +} + +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) { + ep.mu.Lock() + defer ep.mu.Unlock() + + // A new endpoint is added into endpointsArr and its index there is + // saved in endpointsMap. This will allows to remove endpoint from + // the array fast. + ep.endpointsMap[ep] = len(ep.endpointsArr) + ep.endpointsArr = append(ep.endpointsArr, t) +} + +// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. +func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { + ep.mu.Lock() + defer ep.mu.Unlock() + + idx, ok := ep.endpointsMap[t] + if !ok { + return false + } + delete(ep.endpointsMap, t) + l := len(ep.endpointsArr) + if l > 1 { + // The last endpoint in endpointsArr is moved instead of the deleted one. + lastEp := ep.endpointsArr[l-1] + ep.endpointsArr[idx] = lastEp + ep.endpointsMap[lastEp] = idx + ep.endpointsArr = ep.endpointsArr[0 : l-1] + return false + } + return true +} + +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { + if id.RemotePort != 0 { + reusePort = false + } + eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { return nil @@ -77,10 +186,29 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() + var multiPortEp *multiPortEndpoint if _, ok := eps.endpoints[id]; ok { - return tcpip.ErrPortInUse + if !reusePort { + return tcpip.ErrPortInUse + } + multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint) + if !ok { + return tcpip.ErrPortInUse + } } + if reusePort { + if multiPortEp == nil { + multiPortEp = &multiPortEndpoint{} + multiPortEp.endpointsMap = make(map[TransportEndpoint]int) + multiPortEp.seed = rand.Uint32() + eps.endpoints[id] = multiPortEp + } + + multiPortEp.singleRegisterEndpoint(ep) + + return nil + } eps.endpoints[id] = ep return nil @@ -88,12 +216,10 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.mu.Lock() - delete(eps.endpoints, id) - eps.mu.Unlock() + eps.unregisterEndpoint(id, ep) } } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index f09760180..022207081 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -107,7 +107,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.id.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f) + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false) if err != nil { return err } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 627786808..7d4fbe075 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -436,6 +436,10 @@ type CorkOption int // should allow reuse of local address. type ReuseAddressOption int +// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets +// to be bound to an identical socket address. +type ReusePortOption int + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go index d1b9b136c..29f6c543d 100644 --- a/pkg/tcpip/transport/ping/endpoint.go +++ b/pkg/tcpip/transport/ping/endpoint.go @@ -100,7 +100,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id) + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) } // Close the receive list and drain it. @@ -541,14 +541,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) switch err { case nil: return true, nil @@ -597,7 +597,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error if commit != nil { if err := commit(); err != nil { // Unregister, the commit failed. - e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id) + e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id, e) return err } } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d0e1d6782..78d2c76e0 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -215,7 +215,7 @@ func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, ir n.maybeEnableSACKPermitted(rcvdSynOpts) // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { n.Close() return nil, err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d4eda50ec..5281f8be2 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -162,6 +162,9 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo + // reusePort is set to true if SO_REUSEPORT is enabled. + reusePort bool + // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -416,7 +419,7 @@ func (e *endpoint) Close() { e.isPortReserved = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) e.isRegistered = false } } @@ -453,7 +456,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) } e.route.Release() @@ -681,6 +684,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil + case tcpip.QuickAckOption: if v == 0 { atomic.StoreUint32(&e.slowAck, 1) @@ -875,6 +884,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.ReusePortOption: + e.mu.RLock() + v := e.reusePort + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + case *tcpip.QuickAckOption: *o = 1 if v := atomic.LoadUint32(&e.slowAck); v != 0 { @@ -1057,7 +1077,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er if e.id.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort) if err != nil { return err } @@ -1071,13 +1091,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er if sameAddr && p == e.id.RemotePort { return false, nil } - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) { + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) { return false, nil } id := e.id id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) { + switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) { case nil: e.id = id return true, nil @@ -1234,7 +1254,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil { return err } @@ -1315,7 +1335,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 67e9ca0ac..b2a27a7cb 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -81,6 +81,7 @@ type endpoint struct { dstPort uint16 v6only bool multicastTTL uint8 + reusePort bool // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -132,7 +133,7 @@ func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.Transport ep := newEndpoint(stack, r.NetProto, waiterQueue) // Register new endpoint so that packets are routed to it. - if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil { + if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep, ep.reusePort); err != nil { ep.Close() return nil, err } @@ -155,7 +156,7 @@ func (e *endpoint) Close() { switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) } @@ -449,6 +450,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { break } } + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil } return nil } @@ -513,6 +520,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.ReusePortOption: + e.mu.RLock() + v := e.reusePort + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + case *tcpip.KeepaliveEnabledOption: *o = 0 return nil @@ -648,7 +666,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.id.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) } e.id = id @@ -711,14 +729,14 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if e.id.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort) if err != nil { return id, err } id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) } @@ -766,7 +784,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error if commit != nil { if err := commit(); err != nil { // Unregister, the commit failed. - e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id) + e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id, e) e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) return err } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 58a346cd9..2a9cf4b57 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "math" "math/rand" "testing" "time" @@ -254,6 +255,90 @@ func newPayload() []byte { return b } +func TestBindPortReuse(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + var eps [5]tcpip.Endpoint + reusePortOpt := tcpip.ReusePortOption(1) + + pollChannel := make(chan tcpip.Endpoint) + for i := 0; i < len(eps); i++ { + // Try to receive the data. + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + var err *tcpip.Error + eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + go func(ep tcpip.Endpoint) { + for range ch { + pollChannel <- ep + } + }(eps[i]) + + defer eps[i].Close() + if err := eps[i].SetSockOpt(reusePortOpt); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } + if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, nil); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + } + + npackets := 100000 + nports := 10000 + ports := make(map[uint16]tcpip.Endpoint) + stats := make(map[tcpip.Endpoint]int) + for i := 0; i < npackets; i++ { + // Send a packet. + port := uint16(i % nports) + payload := newPayload() + c.sendV6Packet(payload, &headers{ + srcPort: testPort + port, + dstPort: stackPort, + }) + + var addr tcpip.FullAddress + ep := <-pollChannel + _, _, err := ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read failed: %v", err) + } + stats[ep]++ + if i < nports { + ports[uint16(i)] = ep + } else { + // Check that all packets from one client are handled + // by the same socket. + if ports[port] != ep { + t.Fatalf("Port mismatch") + } + } + } + + if len(stats) != len(eps) { + t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps)) + } + + // Check that a packet distribution is fair between sockets. + for _, c := range stats { + n := float64(npackets) / float64(len(eps)) + // The deviation is less than 10%. + if math.Abs(float64(c)-n) > n/10 { + t.Fatal(c, n) + } + } +} + func testV4Read(c *testContext) { // Send a packet. payload := newPayload() diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index ae33d14da..f0e61e083 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2163,9 +2163,13 @@ cc_binary( ":socket_test_util", "//test/util:file_descriptor", "//test/util:posix_error", + "//test/util:save_util", "//test/util:test_main", "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 17a46e149..0893be5a7 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -17,17 +17,24 @@ #include <string.h> #include <sys/socket.h> +#include <atomic> +#include <memory> #include <string> #include <tuple> #include <utility> #include <vector> +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { @@ -227,6 +234,238 @@ INSTANTIATE_TEST_CASE_P( TestParam{V6Loopback(), V6Loopback()}), DescribeTestParam); +using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>; + +TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + constexpr int kThreadCount = 3; + + // Create the listening socket. + FileDescriptor listener_fds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + listener_fds[i] = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + int fd = listener_fds[i].get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (i != 0) continue; + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic<int> connects_received = ATOMIC_VAR_INIT(0); + std::unique_ptr<ScopedThread> listen_thread[kThreadCount]; + int accept_counts[kThreadCount] = {}; + // TODO: figure how to not disable S/R for the whole test. + // We need to take into account that this test executes a lot of system + // calls from many threads. + DisableSave ds; + + for (int i = 0; i < kThreadCount; i++) { + listen_thread[i] = absl::make_unique<ScopedThread>( + [&listener_fds, &accept_counts, i, &connects_received]() { + do { + auto fd = Accept(listener_fds[i].get(), nullptr, nullptr); + if (!fd.ok()) { + if (connects_received >= kConnectAttempts) { + // Another thread have shutdown our read side causing the + // accept to fail. + break; + } + ASSERT_NO_ERRNO(fd); + break; + } + // Receive some data from a socket to be sure that the connect() + // system call has been completed on another side. + int data; + EXPECT_THAT( + RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + accept_counts[i]++; + } while (++connects_received < kConnectAttempts); + + // Shutdown all sockets to wake up other threads. + for (int j = 0; j < kThreadCount; j++) { + shutdown(listener_fds[j].get(), SHUT_RDWR); + } + }); + } + + ScopedThread connecting_thread([&connector, &conn_addr]() { + for (int i = 0; i < kConnectAttempts; i++) { + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), + SyscallSucceedsWithValue(sizeof(i))); + } + }); + + // Join threads to be sure that all connections have been counted + connecting_thread.Join(); + for (int i = 0; i < kThreadCount; i++) { + listen_thread[i]->Join(); + } + // Check that connections are distributed fairly between listening sockets + for (int i = 0; i < kThreadCount; i++) + EXPECT_THAT(accept_counts[i], + EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); +} + +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + constexpr int kThreadCount = 3; + + // Create the listening socket. + FileDescriptor listener_fds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + listener_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)); + int fd = listener_fds[i].get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (i != 0) continue; + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic<int> packets_received = ATOMIC_VAR_INIT(0); + std::unique_ptr<ScopedThread> receiver_thread[kThreadCount]; + int packets_per_socket[kThreadCount] = {}; + // TODO: figure how to not disable S/R for the whole test. + DisableSave ds; // Too expensive. + + for (int i = 0; i < kThreadCount; i++) { + receiver_thread[i] = absl::make_unique<ScopedThread>( + [&listener_fds, &packets_per_socket, i, &packets_received]() { + do { + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + int data; + + auto ret = RetryEINTR(recvfrom)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast<struct sockaddr*>(&addr), &addrlen); + + if (packets_received < kConnectAttempts) { + ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); + } + + if (ret != sizeof(data)) { + // Another thread may have shutdown our read side causing the + // recvfrom to fail. + break; + } + + packets_received++; + packets_per_socket[i]++; + + // A response is required to synchronize with the main thread, + // otherwise the main thread can send more than can fit into receive + // queues. + EXPECT_THAT(RetryEINTR(sendto)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(data))); + } while (packets_received < kConnectAttempts); + + // Shutdown all sockets to wake up other threads. + for (int j = 0; j < kThreadCount; j++) + shutdown(listener_fds[j].get(), SHUT_RDWR); + }); + } + + ScopedThread main_thread([&connector, &conn_addr]() { + for (int i = 0; i < kConnectAttempts; i++) { + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + int data; + EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + } + }); + + main_thread.Join(); + + // Join threads to be sure that all connections have been counted + for (int i = 0; i < kThreadCount; i++) { + receiver_thread[i]->Join(); + } + // Check that packets are distributed fairly between listening sockets. + for (int i = 0; i < kThreadCount; i++) + EXPECT_THAT(packets_per_socket[i], + EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); +} + +INSTANTIATE_TEST_CASE_P( + All, SocketInetReusePortTest, + ::testing::Values( + // Listeners bound to IPv4 addresses refuse connections using IPv6 + // addresses. + TestParam{V4Any(), V4Loopback()}, + TestParam{V4Loopback(), V4MappedLoopback()}, + + // Listeners bound to IN6ADDR_ANY accept all connections. + TestParam{V6Any(), V4Loopback()}, TestParam{V6Any(), V6Loopback()}, + + // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4 + // addresses. + TestParam{V6Loopback(), V6Loopback()}), + DescribeTestParam); + struct ProtocolTestParam { std::string description; int type; @@ -806,6 +1045,56 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { } } +TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { + auto const& param = GetParam(); + TestAddress const& test_addr = V4Loopback(); + sockaddr_storage addr = test_addr.addr; + + for (int i = 0; i < 2; i++) { + const int portreuse1 = i % 2; + auto s1 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + int fd1 = s1.get(); + socklen_t addrlen = test_addr.addr_len; + + EXPECT_THAT( + setsockopt(fd1, SOL_SOCKET, SO_REUSEPORT, &portreuse1, sizeof(int)), + SyscallSucceeds()); + + ASSERT_THAT(bind(fd1, reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(getsockname(fd1, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(fd1, 1), SyscallSucceeds()); + } + + // j is less than 4 to check that the port reuse logic works correctly after + // closing bound sockets. + for (int j = 0; j < 4; j++) { + const int portreuse2 = j % 2; + auto s2 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + int fd2 = s2.get(); + + EXPECT_THAT( + setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)), + SyscallSucceeds()); + + LOG(INFO) << portreuse1 << " " << portreuse2; + int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen); + + // Verify that two sockets can be bound to the same port only if + // SO_REUSEPORT is set for both of them. + if (!portreuse1 || !portreuse2) + ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRINUSE)); + else + ASSERT_THAT(ret, SyscallSucceeds()); + } + } +} + INSTANTIATE_TEST_CASE_P(AllFamlies, SocketMultiProtocolInetLoopbackTest, ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, ProtocolTestParam{"UDP", SOCK_DGRAM}), diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go index 9ee0361ee..ec048f10f 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/syscalls/syscall_test_runner.go @@ -118,6 +118,7 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Mark the root as writeable, as some tests attempt to // write to the rootfs, and expect EACCES, not EROFS. spec.Root.Readonly = false + spec.Mounts = nil // Set environment variable that indicates we are // running in gVisor and with the given platform. |