diff options
author | Andrei Vagin <avagin@google.com> | 2018-12-28 11:26:01 -0800 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-12-28 11:27:14 -0800 |
commit | 652d068119052b0b3bc4a0808a4400a22380a30b (patch) | |
tree | f5a617063151ffb9563ebbcd3189611e854952db /pkg/tcpip/transport | |
parent | a3217b71723a93abb7a2aca535408ab84d81ac2f (diff) |
Implement SO_REUSEPORT for TCP and UDP sockets
This option allows multiple sockets to be bound to the same port.
Incoming packets are distributed to sockets using a hash based on source and
destination addresses. This means that all packets from one sender will be
received by the same server socket.
PiperOrigin-RevId: 227153413
Change-Id: I59b6edda9c2209d5b8968671e9129adb675920cf
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/ping/endpoint.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 85 |
5 files changed, 141 insertions, 18 deletions
diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go index d1b9b136c..29f6c543d 100644 --- a/pkg/tcpip/transport/ping/endpoint.go +++ b/pkg/tcpip/transport/ping/endpoint.go @@ -100,7 +100,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id) + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) } // Close the receive list and drain it. @@ -541,14 +541,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) switch err { case nil: return true, nil @@ -597,7 +597,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error if commit != nil { if err := commit(); err != nil { // Unregister, the commit failed. - e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id) + e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id, e) return err } } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d0e1d6782..78d2c76e0 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -215,7 +215,7 @@ func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, ir n.maybeEnableSACKPermitted(rcvdSynOpts) // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { n.Close() return nil, err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d4eda50ec..5281f8be2 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -162,6 +162,9 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo + // reusePort is set to true if SO_REUSEPORT is enabled. + reusePort bool + // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -416,7 +419,7 @@ func (e *endpoint) Close() { e.isPortReserved = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) e.isRegistered = false } } @@ -453,7 +456,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) } e.route.Release() @@ -681,6 +684,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil + case tcpip.QuickAckOption: if v == 0 { atomic.StoreUint32(&e.slowAck, 1) @@ -875,6 +884,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.ReusePortOption: + e.mu.RLock() + v := e.reusePort + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + case *tcpip.QuickAckOption: *o = 1 if v := atomic.LoadUint32(&e.slowAck); v != 0 { @@ -1057,7 +1077,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er if e.id.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort) if err != nil { return err } @@ -1071,13 +1091,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er if sameAddr && p == e.id.RemotePort { return false, nil } - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) { + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) { return false, nil } id := e.id id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) { + switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) { case nil: e.id = id return true, nil @@ -1234,7 +1254,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil { return err } @@ -1315,7 +1335,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 67e9ca0ac..b2a27a7cb 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -81,6 +81,7 @@ type endpoint struct { dstPort uint16 v6only bool multicastTTL uint8 + reusePort bool // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -132,7 +133,7 @@ func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.Transport ep := newEndpoint(stack, r.NetProto, waiterQueue) // Register new endpoint so that packets are routed to it. - if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil { + if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep, ep.reusePort); err != nil { ep.Close() return nil, err } @@ -155,7 +156,7 @@ func (e *endpoint) Close() { switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) } @@ -449,6 +450,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { break } } + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil } return nil } @@ -513,6 +520,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.ReusePortOption: + e.mu.RLock() + v := e.reusePort + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + case *tcpip.KeepaliveEnabledOption: *o = 0 return nil @@ -648,7 +666,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.id.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) } e.id = id @@ -711,14 +729,14 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if e.id.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort) if err != nil { return id, err } id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) } @@ -766,7 +784,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error if commit != nil { if err := commit(); err != nil { // Unregister, the commit failed. - e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id) + e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id, e) e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) return err } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 58a346cd9..2a9cf4b57 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "math" "math/rand" "testing" "time" @@ -254,6 +255,90 @@ func newPayload() []byte { return b } +func TestBindPortReuse(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + var eps [5]tcpip.Endpoint + reusePortOpt := tcpip.ReusePortOption(1) + + pollChannel := make(chan tcpip.Endpoint) + for i := 0; i < len(eps); i++ { + // Try to receive the data. + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + var err *tcpip.Error + eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + go func(ep tcpip.Endpoint) { + for range ch { + pollChannel <- ep + } + }(eps[i]) + + defer eps[i].Close() + if err := eps[i].SetSockOpt(reusePortOpt); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } + if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, nil); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + } + + npackets := 100000 + nports := 10000 + ports := make(map[uint16]tcpip.Endpoint) + stats := make(map[tcpip.Endpoint]int) + for i := 0; i < npackets; i++ { + // Send a packet. + port := uint16(i % nports) + payload := newPayload() + c.sendV6Packet(payload, &headers{ + srcPort: testPort + port, + dstPort: stackPort, + }) + + var addr tcpip.FullAddress + ep := <-pollChannel + _, _, err := ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read failed: %v", err) + } + stats[ep]++ + if i < nports { + ports[uint16(i)] = ep + } else { + // Check that all packets from one client are handled + // by the same socket. + if ports[port] != ep { + t.Fatalf("Port mismatch") + } + } + } + + if len(stats) != len(eps) { + t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps)) + } + + // Check that a packet distribution is fair between sockets. + for _, c := range stats { + n := float64(npackets) / float64(len(eps)) + // The deviation is less than 10%. + if math.Abs(float64(c)-n) > n/10 { + t.Fatal(c, n) + } + } +} + func testV4Read(c *testContext) { // Send a packet. payload := newPayload() |