summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
authorAndrei Vagin <avagin@google.com>2018-12-28 11:26:01 -0800
committerShentubot <shentubot@google.com>2018-12-28 11:27:14 -0800
commit652d068119052b0b3bc4a0808a4400a22380a30b (patch)
treef5a617063151ffb9563ebbcd3189611e854952db /pkg/tcpip/stack
parenta3217b71723a93abb7a2aca535408ab84d81ac2f (diff)
Implement SO_REUSEPORT for TCP and UDP sockets
This option allows multiple sockets to be bound to the same port. Incoming packets are distributed to sockets using a hash based on source and destination addresses. This means that all packets from one sender will be received by the same server socket. PiperOrigin-RevId: 227153413 Change-Id: I59b6edda9c2209d5b8968671e9129adb675920cf
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/stack.go12
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go144
-rw-r--r--pkg/tcpip/stack/transport_test.go2
4 files changed, 143 insertions, 16 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 90cc05cda..9ff1c8731 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -22,6 +22,7 @@ go_library(
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0ac116675..7aa9dbd46 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -883,9 +883,9 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
if nicID == 0 {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep)
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
}
s.mu.RLock()
@@ -896,14 +896,14 @@ func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.N
return tcpip.ErrUnknownNICID
}
- return nic.demux.registerEndpoint(netProtos, protocol, id, ep)
+ return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
if nicID == 0 {
- s.demux.unregisterEndpoint(netProtos, protocol, id)
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
return
}
@@ -912,7 +912,7 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
nic := s.nics[nicID]
if nic != nil {
- nic.demux.unregisterEndpoint(netProtos, protocol, id)
+ nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index c8522ad9e..a5ff2159a 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -15,10 +15,12 @@
package stack
import (
+ "math/rand"
"sync"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
)
@@ -34,6 +36,23 @@ type transportEndpoints struct {
endpoints map[TransportEndpointID]TransportEndpoint
}
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ e, ok := eps.endpoints[id]
+ if !ok {
+ return
+ }
+ if multiPortEp, ok := e.(*multiPortEndpoint); ok {
+ if !multiPortEp.unregisterEndpoint(ep) {
+ return
+ }
+ }
+ delete(eps.endpoints, id)
+}
+
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
@@ -57,10 +76,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
return err
}
}
@@ -68,7 +87,97 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
return nil
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+// multiPortEndpoint is a container for TransportEndpoints which are bound to
+// the same pair of address and port.
+type multiPortEndpoint struct {
+ mu sync.RWMutex
+ endpointsArr []TransportEndpoint
+ endpointsMap map[TransportEndpoint]int
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// reciprocalScale scales a value into range [0, n).
+//
+// This is similar to val % n, but faster.
+// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
+func reciprocalScale(val, n uint32) uint32 {
+ return uint32((uint64(val) * uint64(n)) >> 32)
+}
+
+// selectEndpoint calculates a hash of destination and source addresses and
+// ports then uses it to select a socket. In this case, all packets from one
+// address will be sent to same endpoint.
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ payload := []byte{
+ byte(id.LocalPort),
+ byte(id.LocalPort >> 8),
+ byte(id.RemotePort),
+ byte(id.RemotePort >> 8),
+ }
+
+ h := jenkins.Sum32(ep.seed)
+ h.Write(payload)
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+ hash := h.Sum32()
+
+ idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
+ return ep.endpointsArr[idx]
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ ep.selectEndpoint(id).HandlePacket(r, id, vv)
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
+}
+
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ // A new endpoint is added into endpointsArr and its index there is
+ // saved in endpointsMap. This will allows to remove endpoint from
+ // the array fast.
+ ep.endpointsMap[ep] = len(ep.endpointsArr)
+ ep.endpointsArr = append(ep.endpointsArr, t)
+}
+
+// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ idx, ok := ep.endpointsMap[t]
+ if !ok {
+ return false
+ }
+ delete(ep.endpointsMap, t)
+ l := len(ep.endpointsArr)
+ if l > 1 {
+ // The last endpoint in endpointsArr is moved instead of the deleted one.
+ lastEp := ep.endpointsArr[l-1]
+ ep.endpointsArr[idx] = lastEp
+ ep.endpointsMap[lastEp] = idx
+ ep.endpointsArr = ep.endpointsArr[0 : l-1]
+ return false
+ }
+ return true
+}
+
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ if id.RemotePort != 0 {
+ reusePort = false
+ }
+
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
return nil
@@ -77,10 +186,29 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.mu.Lock()
defer eps.mu.Unlock()
+ var multiPortEp *multiPortEndpoint
if _, ok := eps.endpoints[id]; ok {
- return tcpip.ErrPortInUse
+ if !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
+ if !ok {
+ return tcpip.ErrPortInUse
+ }
}
+ if reusePort {
+ if multiPortEp == nil {
+ multiPortEp = &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.seed = rand.Uint32()
+ eps.endpoints[id] = multiPortEp
+ }
+
+ multiPortEp.singleRegisterEndpoint(ep)
+
+ return nil
+ }
eps.endpoints[id] = ep
return nil
@@ -88,12 +216,10 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.mu.Lock()
- delete(eps.endpoints, id)
- eps.mu.Unlock()
+ eps.unregisterEndpoint(id, ep)
}
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index f09760180..022207081 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -107,7 +107,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.id.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
if err != nil {
return err
}