summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/registration.go13
-rw-r--r--pkg/tcpip/stack/stack.go20
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go67
-rw-r--r--pkg/tcpip/stack/transport_test.go2
4 files changed, 53 insertions, 49 deletions
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ff356ea22..f3cc849ec 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -64,13 +64,24 @@ const (
type TransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint.
- HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView)
+ HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
// HandleControlPacket is called by the stack when new control (e.g.,
// ICMP) packets arrive to this transport endpoint.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
}
+// RawTransportEndpoint is the interface that needs to be implemented by raw
+// transport protocol endpoints. RawTransportEndpoints receive the entire
+// packet - including the link, network, and transport headers - as delivered
+// to netstack.
+type RawTransportEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive to
+ // this transport endpoint. The packet contains all data from the link
+ // layer up.
+ HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
+}
+
// TransportProtocol is the interface that needs to be implemented by transport
// protocols (e.g., tcp, udp) that want to be part of the networking stack.
type TransportProtocol interface {
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 15a268b10..a74c0a7a0 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -955,11 +955,11 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
-// transport dispatcher. Received packets that match the provided protocol will
-// be delivered to the given endpoint.
-func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+// transport dispatcher. Received packets that match the provided transport
+// protocol will be delivered to the given endpoint.
+func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
if nicID == 0 {
- return s.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort)
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
s.mu.RLock()
@@ -970,14 +970,14 @@ func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpi
return tcpip.ErrUnknownNICID
}
- return nic.demux.registerRawEndpoint(netProtos, protocol, ep, reusePort)
+ return nic.demux.registerRawEndpoint(netProto, transProto, ep)
}
-// UnregisterRawTransportEndpoint removes the endpoint for the protocol from
-// the stack transport dispatcher.
-func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) {
+// UnregisterRawTransportEndpoint removes the endpoint for the transport
+// protocol from the stack transport dispatcher.
+func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
if nicID == 0 {
- s.demux.unregisterRawEndpoint(netProtos, protocol, ep)
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
return
}
@@ -986,7 +986,7 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProtos []tc
nic := s.nics[nicID]
if nic != nil {
- nic.demux.unregisterRawEndpoint(netProtos, protocol, ep)
+ nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 9ab314188..a8ac18e72 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"math/rand"
"sync"
@@ -37,7 +38,7 @@ type transportEndpoints struct {
endpoints map[TransportEndpointID]TransportEndpoint
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
- rawEndpoints []TransportEndpoint
+ rawEndpoints []RawTransportEndpoint
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
@@ -60,8 +61,10 @@ func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep Tra
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
-// based on endpoints IDs.
+// based on endpoints IDs. It should only be instantiated via
+// newTransportDemuxer.
type transportDemuxer struct {
+ // protocol is immutable.
protocol map[protocolIDs]*transportEndpoints
}
@@ -137,22 +140,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) {
+func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
// If this is a broadcast datagram, deliver the datagram to all endpoints
// managed by ep.
if id.LocalAddress == header.IPv4Broadcast {
for i, endpoint := range ep.endpointsArr {
// HandlePacket modifies vv, so each endpoint needs its own copy.
if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, netHeader, vv)
+ endpoint.HandlePacket(r, id, vv)
break
}
vvCopy := buffer.NewView(vv.Size())
copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vvCopy.ToVectorisedView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
}
} else {
- ep.selectEndpoint(id).HandlePacket(r, id, netHeader, vv)
+ ep.selectEndpoint(id).HandlePacket(r, id, vv)
}
}
@@ -286,17 +289,17 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
// raw endpoint first. If there are multipe raw endpoints, they all
// receive the packet.
- found := false
+ foundRaw := false
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, id, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
- found = true
+ rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ foundRaw = true
}
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 && !found {
+ if len(destEps) == 0 && !foundRaw {
// UDP packet could not be delivered to an unknown destination port.
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
@@ -306,7 +309,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, netHeader, vv)
+ ep.HandlePacket(r, id, vv)
}
return true
@@ -371,19 +374,8 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
// that packets of the appropriate protocol are delivered to it. A single
// packet can be sent to one or more raw endpoints along with a non-raw
// endpoint.
-func (d *transportDemuxer) registerRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint, reusePort bool) *tcpip.Error {
- for i, n := range netProtos {
- if err := d.singleRegisterRawEndpoint(n, protocol, ep); err != nil {
- d.unregisterRawEndpoint(netProtos[:i], protocol, ep)
- return err
- }
- }
-
- return nil
-}
-
-func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) *tcpip.Error {
- eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
}
@@ -395,19 +387,20 @@ func (d *transportDemuxer) singleRegisterRawEndpoint(netProto tcpip.NetworkProto
return nil
}
-// unregisterRawEndpoint unregisters the raw endpoint for the given protocol
-// such that it won't receive any more packets.
-func (d *transportDemuxer) unregisterRawEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, ep TransportEndpoint) {
- for _, n := range netProtos {
- if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.mu.Lock()
- defer eps.mu.Unlock()
- for i, rawEP := range eps.rawEndpoints {
- if rawEP == ep {
- eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
- return
- }
- }
+// unregisterRawEndpoint unregisters the raw endpoint for the given transport
+// protocol such that it won't receive any more packets.
+func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
+ }
+
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ for i, rawEP := range eps.rawEndpoints {
+ if rawEP == ep {
+ eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
+ return
}
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index dfd31557a..0c2589083 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -168,7 +168,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.View, _ buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {