summary refs log tree commit diff homepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
author Andrei Vagin <avagin@google.com> 2018-12-28 11:26:01 -0800
committer Shentubot <shentubot@google.com> 2018-12-28 11:27:14 -0800
commit 652d068119052b0b3bc4a0808a4400a22380a30b (patch)
tree f5a617063151ffb9563ebbcd3189611e854952db /pkg/tcpip/transport
parent a3217b71723a93abb7a2aca535408ab84d81ac2f (diff)
Implement SO_REUSEPORT for TCP and UDP sockets
This option allows multiple sockets to be bound to the same port. Incoming packets are distributed to sockets using a hash based on source and destination addresses. This means that all packets from one sender will be received by the same server socket.

PiperOrigin-RevId: 227153413
Change-Id: I59b6edda9c2209d5b8968671e9129adb675920cf
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- pkg/tcpip/transport/ping/endpoint.go 8
-rw-r--r-- pkg/tcpip/transport/tcp/accept.go 2
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint.go 34
-rw-r--r-- pkg/tcpip/transport/udp/endpoint.go 30
-rw-r--r-- pkg/tcpip/transport/udp/udp_test.go 85
5 files changed, 141 insertions, 18 deletions
diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go
index d1b9b136c..29f6c543d 100644
--- a/pkg/tcpip/transport/ping/endpoint.go
+++ b/pkg/tcpip/transport/ping/endpoint.go
@@ -100,7 +100,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
}
// Close the receive list and drain it.
@@ -541,14 +541,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
switch err {
case nil:
return true, nil
@@ -597,7 +597,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error
if commit != nil {
if err := commit(); err != nil {
// Unregister, the commit failed.
- e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id)
+ e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id, e)
return err
}
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d0e1d6782..78d2c76e0 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -215,7 +215,7 @@ func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, ir
n.maybeEnableSACKPermitted(rcvdSynOpts)
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
n.Close()
return nil, err
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d4eda50ec..5281f8be2 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -162,6 +162,9 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
+ // reusePort is set to true if SO_REUSEPORT is enabled.
+ reusePort bool
+
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -416,7 +419,7 @@ func (e *endpoint) Close() {
e.isPortReserved = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
e.isRegistered = false
}
}
@@ -453,7 +456,7 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
}
e.route.Release()
@@ -681,6 +684,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
case tcpip.QuickAckOption:
if v == 0 {
atomic.StoreUint32(&e.slowAck, 1)
@@ -875,6 +884,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
case *tcpip.QuickAckOption:
*o = 1
if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1057,7 +1077,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
if e.id.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
if err != nil {
return err
}
@@ -1071,13 +1091,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
if sameAddr && p == e.id.RemotePort {
return false, nil
}
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) {
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
return false, nil
}
id := e.id
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) {
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
case nil:
e.id = id
return true, nil
@@ -1234,7 +1254,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
return err
}
@@ -1315,7 +1335,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 67e9ca0ac..b2a27a7cb 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -81,6 +81,7 @@ type endpoint struct {
dstPort uint16
v6only bool
multicastTTL uint8
+ reusePort bool
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -132,7 +133,7 @@ func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.Transport
ep := newEndpoint(stack, r.NetProto, waiterQueue)
// Register new endpoint so that packets are routed to it.
- if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil {
+ if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep, ep.reusePort); err != nil {
ep.Close()
return nil, err
}
@@ -155,7 +156,7 @@ func (e *endpoint) Close() {
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
}
@@ -449,6 +450,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
break
}
}
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
}
return nil
}
@@ -513,6 +520,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case *tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -648,7 +666,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.id.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
}
e.id = id
@@ -711,14 +729,14 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if e.id.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
if err != nil {
return id, err
}
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
if err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
}
@@ -766,7 +784,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error
if commit != nil {
if err := commit(); err != nil {
// Unregister, the commit failed.
- e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id)
+ e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id, e)
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
return err
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 58a346cd9..2a9cf4b57 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "math"
"math/rand"
"testing"
"time"
@@ -254,6 +255,90 @@ func newPayload() []byte {
return b
}
+// TestBindPortReuse binds several UDP endpoints to the same address and port
+// with ReusePortOption set, sends packets from many different source ports,
+// and checks that:
+//   - every packet from a given source port is delivered to the same endpoint,
+//   - every endpoint receives at least one packet, and
+//   - each endpoint's packet count is within 10% of an even split.
+func TestBindPortReuse(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	var eps [5]tcpip.Endpoint
+	reusePortOpt := tcpip.ReusePortOption(1)
+
+	// pollChannel reports which endpoint's wait queue was notified. It is
+	// unbuffered, so each forwarding goroutine below blocks until the main
+	// loop consumes its notification.
+	pollChannel := make(chan tcpip.Endpoint)
+	for i := 0; i < len(eps); i++ {
+		// Register for EventIn notifications on this endpoint's queue.
+		wq := waiter.Queue{}
+		we, ch := waiter.NewChannelEntry(nil)
+		wq.EventRegister(&we, waiter.EventIn)
+		defer wq.EventUnregister(&we)
+		defer close(ch)
+
+		var err *tcpip.Error
+		eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+		if err != nil {
+			t.Fatalf("NewEndpoint failed: %v", err)
+		}
+
+		// Forward every notification for this endpoint to pollChannel. The
+		// goroutine exits once ch is closed by the deferred close above.
+		// NOTE(review): a notification arriving after the main loop stops
+		// reading pollChannel would leave this goroutine blocked on the send;
+		// confirm whether Close can trigger one.
+		go func(ep tcpip.Endpoint) {
+			for range ch {
+				pollChannel <- ep
+			}
+		}(eps[i])
+
+		defer eps[i].Close()
+		if err := eps[i].SetSockOpt(reusePortOpt); err != nil {
+			t.Fatalf("SetSockOpt failed: %v", err)
+		}
+		if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, nil); err != nil {
+			t.Fatalf("ep.Bind(...) failed: %v", err)
+		}
+	}
+
+	npackets := 100000
+	nports := 10000
+	// ports remembers which endpoint received the first packet from each
+	// source-port offset; stats counts packets per endpoint.
+	ports := make(map[uint16]tcpip.Endpoint)
+	stats := make(map[tcpip.Endpoint]int)
+	for i := 0; i < npackets; i++ {
+		// Send a packet, cycling through nports distinct source ports.
+		port := uint16(i % nports)
+		payload := newPayload()
+		c.sendV6Packet(payload, &headers{
+			srcPort: testPort + port,
+			dstPort: stackPort,
+		})
+
+		// Wait for whichever endpoint was notified and drain the packet.
+		var addr tcpip.FullAddress
+		ep := <-pollChannel
+		_, _, err := ep.Read(&addr)
+		if err != nil {
+			t.Fatalf("Read failed: %v", err)
+		}
+		stats[ep]++
+		if i < nports {
+			// First pass over the source ports: record the receiver.
+			ports[port] = ep
+		} else {
+			// Check that all packets from one client are handled
+			// by the same socket.
+			if ports[port] != ep {
+				t.Fatalf("Port mismatch: packet %d from source port offset %d was delivered to a different endpoint than the first packet from that port", i, port)
+			}
+		}
+	}
+
+	if len(stats) != len(eps) {
+		t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps))
+	}
+
+	// Check that the packet distribution is fair between sockets.
+	expected := float64(npackets) / float64(len(eps))
+	for _, count := range stats {
+		// The deviation must be less than 10%.
+		if math.Abs(float64(count)-expected) > expected/10 {
+			t.Fatalf("an endpoint received %d packets, want within 10%% of %.0f", count, expected)
+		}
+	}
+}
+
+
func testV4Read(c *testContext) {
// Send a packet.
payload := newPayload()