diff options
author | Andrei Vagin <avagin@google.com> | 2018-12-28 11:26:01 -0800 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-12-28 11:27:14 -0800 |
commit | 652d068119052b0b3bc4a0808a4400a22380a30b (patch) | |
tree | f5a617063151ffb9563ebbcd3189611e854952db /pkg/tcpip/stack/transport_demuxer.go | |
parent | a3217b71723a93abb7a2aca535408ab84d81ac2f (diff) |
Implement SO_REUSEPORT for TCP and UDP sockets
This option allows multiple sockets to be bound to the same port.
Incoming packets are distributed to sockets using a hash of the source and
destination addresses and ports. This means that all packets from one sender
(the same source address and port) will be received by the same server socket.
PiperOrigin-RevId: 227153413
Change-Id: I59b6edda9c2209d5b8968671e9129adb675920cf
Diffstat (limited to 'pkg/tcpip/stack/transport_demuxer.go')
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 144 |
1 files changed, 135 insertions, 9 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index c8522ad9e..a5ff2159a 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,10 +15,12 @@ package stack import ( + "math/rand" "sync" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/hash/jenkins" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" ) @@ -34,6 +36,23 @@ type transportEndpoints struct { endpoints map[TransportEndpointID]TransportEndpoint } +// unregisterEndpoint unregisters the endpoint with the given id such that it +// won't receive any more packets. +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) { + eps.mu.Lock() + defer eps.mu.Unlock() + e, ok := eps.endpoints[id] + if !ok { + return + } + if multiPortEp, ok := e.(*multiPortEndpoint); ok { + if !multiPortEp.unregisterEndpoint(ep) { + return + } + } + delete(eps.endpoints, id) +} + // transportDemuxer demultiplexes packets targeted at a transport endpoint // (i.e., after they've been parsed by the network layer). It does two levels // of demultiplexing: first based on the network and transport protocols, then @@ -57,10 +76,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. 
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep) return err } } @@ -68,7 +87,97 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { +// multiPortEndpoint is a container for TransportEndpoints which are bound to +// the same pair of address and port. +type multiPortEndpoint struct { + mu sync.RWMutex + endpointsArr []TransportEndpoint + endpointsMap map[TransportEndpoint]int + // seed is a random secret for a jenkins hash. + seed uint32 +} + +// reciprocalScale scales a value into range [0, n). +// +// This is similar to val % n, but faster. +// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ +func reciprocalScale(val, n uint32) uint32 { + return uint32((uint64(val) * uint64(n)) >> 32) +} + +// selectEndpoint calculates a hash of destination and source addresses and +// ports then uses it to select a socket. In this case, all packets from one +// address will be sent to same endpoint. 
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint { + ep.mu.RLock() + defer ep.mu.RUnlock() + + payload := []byte{ + byte(id.LocalPort), + byte(id.LocalPort >> 8), + byte(id.RemotePort), + byte(id.RemotePort >> 8), + } + + h := jenkins.Sum32(ep.seed) + h.Write(payload) + h.Write([]byte(id.LocalAddress)) + h.Write([]byte(id.RemoteAddress)) + hash := h.Sum32() + + idx := reciprocalScale(hash, uint32(len(ep.endpointsArr))) + return ep.endpointsArr[idx] +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + ep.selectEndpoint(id).HandlePacket(r, id, vv) +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { + ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv) +} + +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) { + ep.mu.Lock() + defer ep.mu.Unlock() + + // A new endpoint is added into endpointsArr and its index there is + // saved in endpointsMap. This will allows to remove endpoint from + // the array fast. + ep.endpointsMap[ep] = len(ep.endpointsArr) + ep.endpointsArr = append(ep.endpointsArr, t) +} + +// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. +func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { + ep.mu.Lock() + defer ep.mu.Unlock() + + idx, ok := ep.endpointsMap[t] + if !ok { + return false + } + delete(ep.endpointsMap, t) + l := len(ep.endpointsArr) + if l > 1 { + // The last endpoint in endpointsArr is moved instead of the deleted one. 
+ lastEp := ep.endpointsArr[l-1] + ep.endpointsArr[idx] = lastEp + ep.endpointsMap[lastEp] = idx + ep.endpointsArr = ep.endpointsArr[0 : l-1] + return false + } + return true +} + +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { + if id.RemotePort != 0 { + reusePort = false + } + eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { return nil @@ -77,10 +186,29 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() + var multiPortEp *multiPortEndpoint if _, ok := eps.endpoints[id]; ok { - return tcpip.ErrPortInUse + if !reusePort { + return tcpip.ErrPortInUse + } + multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint) + if !ok { + return tcpip.ErrPortInUse + } } + if reusePort { + if multiPortEp == nil { + multiPortEp = &multiPortEndpoint{} + multiPortEp.endpointsMap = make(map[TransportEndpoint]int) + multiPortEp.seed = rand.Uint32() + eps.endpoints[id] = multiPortEp + } + + multiPortEp.singleRegisterEndpoint(ep) + + return nil + } eps.endpoints[id] = ep return nil @@ -88,12 +216,10 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.mu.Lock() - delete(eps.endpoints, id) - eps.mu.Unlock() + eps.unregisterEndpoint(id, ep) } } } |