summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/transport_demuxer.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/transport_demuxer.go')
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go420
1 files changed, 420 insertions, 0 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
new file mode 100644
index 000000000..605bfadeb
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -0,0 +1,420 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "math/rand"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+type protocolIDs struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+}
+
+// transportEndpoints manages all endpoints of a given protocol. It has its own
+// mutex so as to reduce interference between protocols.
+type transportEndpoints struct {
+ // mu protects all fields of the transportEndpoints.
+ mu sync.RWMutex
+ endpoints map[TransportEndpointID]TransportEndpoint
+ // rawEndpoints contains endpoints for raw sockets, which receive all
+ // traffic of a given protocol regardless of port.
+ rawEndpoints []RawTransportEndpoint
+}
+
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ e, ok := eps.endpoints[id]
+ if !ok {
+ return
+ }
+ if multiPortEp, ok := e.(*multiPortEndpoint); ok {
+ if !multiPortEp.unregisterEndpoint(ep) {
+ return
+ }
+ }
+ delete(eps.endpoints, id)
+}
+
+// transportDemuxer demultiplexes packets targeted at a transport endpoint
+// (i.e., after they've been parsed by the network layer). It does two levels
+// of demultiplexing: first based on the network and transport protocols, then
+// based on endpoints IDs. It should only be instantiated via
+// newTransportDemuxer.
+type transportDemuxer struct {
+ // protocol is immutable.
+ protocol map[protocolIDs]*transportEndpoints
+}
+
+func newTransportDemuxer(stack *Stack) *transportDemuxer {
+ d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
+
+ // Add each network and transport pair to the demuxer.
+ for netProto := range stack.networkProtocols {
+ for proto := range stack.transportProtocols {
+ d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
+ endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ }
+ }
+ }
+
+ return d
+}
+
+// registerEndpoint registers the given endpoint with the dispatcher such that
+// packets that match the endpoint ID are delivered to it.
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ for i, n := range netProtos {
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
+ return err
+ }
+ }
+
+ return nil
+}
+
+// multiPortEndpoint is a container for TransportEndpoints which are bound to
+// the same pair of address and port.
+type multiPortEndpoint struct {
+ mu sync.RWMutex
+ endpointsArr []TransportEndpoint
+ endpointsMap map[TransportEndpoint]int
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// reciprocalScale scales a value into range [0, n).
+//
+// This is similar to val % n, but faster.
+// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
+func reciprocalScale(val, n uint32) uint32 {
+ return uint32((uint64(val) * uint64(n)) >> 32)
+}
+
+// selectEndpoint calculates a hash of destination and source addresses and
+// ports then uses it to select a socket. In this case, all packets from one
+// address will be sent to same endpoint.
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ payload := []byte{
+ byte(id.LocalPort),
+ byte(id.LocalPort >> 8),
+ byte(id.RemotePort),
+ byte(id.RemotePort >> 8),
+ }
+
+ h := jenkins.Sum32(ep.seed)
+ h.Write(payload)
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+ hash := h.Sum32()
+
+ idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
+ return ep.endpointsArr[idx]
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints managed by ep.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
+ for i, endpoint := range ep.endpointsArr {
+ // HandlePacket modifies vv, so each endpoint needs its own copy.
+ if i == len(ep.endpointsArr)-1 {
+ endpoint.HandlePacket(r, id, vv)
+ break
+ }
+ vvCopy := buffer.NewView(vv.Size())
+ copy(vvCopy, vv.ToView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+ }
+ } else {
+ ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
+}
+
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ // A new endpoint is added into endpointsArr and its index there is
+ // saved in endpointsMap. This will allows to remove endpoint from
+ // the array fast.
+ ep.endpointsMap[t] = len(ep.endpointsArr)
+ ep.endpointsArr = append(ep.endpointsArr, t)
+}
+
+// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ idx, ok := ep.endpointsMap[t]
+ if !ok {
+ return false
+ }
+ delete(ep.endpointsMap, t)
+ l := len(ep.endpointsArr)
+ if l > 1 {
+ // The last endpoint in endpointsArr is moved instead of the deleted one.
+ lastEp := ep.endpointsArr[l-1]
+ ep.endpointsArr[idx] = lastEp
+ ep.endpointsMap[lastEp] = idx
+ ep.endpointsArr = ep.endpointsArr[0 : l-1]
+ return false
+ }
+ return true
+}
+
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+ if id.RemotePort != 0 {
+ reusePort = false
+ }
+
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return nil
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+
+ var multiPortEp *multiPortEndpoint
+ if _, ok := eps.endpoints[id]; ok {
+ if !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
+ if !ok {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ if reusePort {
+ if multiPortEp == nil {
+ multiPortEp = &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.seed = rand.Uint32()
+ eps.endpoints[id] = multiPortEp
+ }
+
+ multiPortEp.singleRegisterEndpoint(ep)
+
+ return nil
+ }
+ eps.endpoints[id] = ep
+
+ return nil
+}
+
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+ for _, n := range netProtos {
+ if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
+ eps.unregisterEndpoint(id, ep)
+ }
+ }
+}
+
+var loopbackSubnet = func() tcpip.Subnet {
+ sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
+ if err != nil {
+ panic(err)
+ }
+ return sn
+}()
+
+// deliverPacket attempts to find one or more matching transport endpoints, and
+// then, if matches are found, delivers the packet to them. Returns true if it
+// found one or more endpoints, false otherwise.
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+ if !ok {
+ return false
+ }
+
+ // If a sender bound to the Loopback interface sends a broadcast,
+ // that broadcast must not be delivered to the sender.
+ if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort {
+ return false
+ }
+
+ // If the packet is a broadcast, then find all matching transport endpoints.
+ // Otherwise, try to find a single matching transport endpoint.
+ destEps := make([]TransportEndpoint, 0, 1)
+ eps.mu.RLock()
+
+ if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
+ for epID, endpoint := range eps.endpoints {
+ if epID.LocalPort == id.LocalPort {
+ destEps = append(destEps, endpoint)
+ }
+ }
+ } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
+ destEps = append(destEps, ep)
+ }
+
+ eps.mu.RUnlock()
+
+ // Fail if we didn't find at least one matching transport endpoint.
+ if len(destEps) == 0 {
+ // UDP packet could not be delivered to an unknown destination port.
+ if protocol == header.UDPProtocolNumber {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ }
+ return false
+ }
+
+ // Deliver the packet.
+ for _, ep := range destEps {
+ ep.HandlePacket(r, id, vv)
+ }
+
+ return true
+}
+
+// deliverRawPacket attempts to deliver the given packet and returns whether it
+// was delivered successfully.
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+ if !ok {
+ return false
+ }
+
+ // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
+ // raw endpoint first. If there are multiple raw endpoints, they all
+ // receive the packet.
+ foundRaw := false
+ eps.mu.RLock()
+ for _, rawEP := range eps.rawEndpoints {
+ // Each endpoint gets its own copy of the packet for the sake
+ // of save/restore.
+ rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ foundRaw = true
+ }
+ eps.mu.RUnlock()
+
+ return foundRaw
+}
+
+// deliverControlPacket attempts to deliver the given control packet. Returns
+// true if it found an endpoint, false otherwise.
+func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{net, trans}]
+ if !ok {
+ return false
+ }
+
+ // Try to find the endpoint.
+ eps.mu.RLock()
+ ep := d.findEndpointLocked(eps, vv, id)
+ eps.mu.RUnlock()
+
+ // Fail if we didn't find one.
+ if ep == nil {
+ return false
+ }
+
+ // Deliver the packet.
+ ep.HandleControlPacket(id, typ, extra, vv)
+
+ return true
+}
+
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+ // Try to find a match with the id as provided.
+ if ep, ok := eps.endpoints[id]; ok {
+ return ep
+ }
+
+ // Try to find a match with the id minus the local address.
+ nid := id
+
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ return ep
+ }
+
+ // Try to find a match with the id minus the remote part.
+ nid.LocalAddress = id.LocalAddress
+ nid.RemoteAddress = ""
+ nid.RemotePort = 0
+ if ep, ok := eps.endpoints[nid]; ok {
+ return ep
+ }
+
+ // Try to find a match with only the local port.
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ return ep
+ }
+
+ return nil
+}
+
+// registerRawEndpoint registers the given endpoint with the dispatcher such
+// that packets of the appropriate protocol are delivered to it. A single
+// packet can be sent to one or more raw endpoints along with a non-raw
+// endpoint.
+func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ return nil
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ eps.rawEndpoints = append(eps.rawEndpoints, ep)
+
+ return nil
+}
+
+// unregisterRawEndpoint unregisters the raw endpoint for the given transport
+// protocol such that it won't receive any more packets.
+func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ for i, rawEP := range eps.rawEndpoints {
+ if rawEP == ep {
+ eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
+ return
+ }
+ }
+}