summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/nic.go18
-rw-r--r--pkg/tcpip/stack/registration.go14
-rw-r--r--pkg/tcpip/stack/route.go12
-rw-r--r--pkg/tcpip/stack/stack.go24
-rw-r--r--pkg/tcpip/stack/stack_test.go40
-rw-r--r--pkg/tcpip/stack/transport_test.go2
6 files changed, 75 insertions, 35 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 79f845225..14267bb48 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -28,10 +28,11 @@ import (
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
- stack *Stack
- id tcpip.NICID
- name string
- linkEP LinkEndpoint
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+ loopback bool
demux *transportDemuxer
@@ -62,12 +63,13 @@ const (
NeverPrimaryEndpoint
)
-func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
return &NIC{
stack: stack,
id: id,
name: name,
linkEP: ep,
+ loopback: loopback,
demux: newTransportDemuxer(stack),
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
@@ -407,7 +409,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n.mu.RLock()
for _, ref := range n.endpoints {
if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
+ r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* multicastLoop */)
r.RemoteLinkAddress = remote
ref.ep.HandlePacket(&r, vv)
ref.decRef()
@@ -418,7 +420,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
}
if ref := n.getRef(protocol, dst); ref != nil {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
+ r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* multicastLoop */)
r.RemoteLinkAddress = remote
ref.ep.HandlePacket(&r, vv)
ref.decRef()
@@ -430,7 +432,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
//
// TODO: Should we be forwarding the packet even if promiscuous?
if n.stack.Forwarding() {
- r, err := n.stack.FindRoute(0, "", dst, protocol)
+ r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
n.stack.stats.IP.InvalidAddressesReceived.Increment()
return
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 62acd5919..cf4d52fe9 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -125,6 +125,18 @@ type TransportDispatcher interface {
DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
}
+// PacketLooping specifies where an outbound packet should be sent.
+type PacketLooping byte
+
+const (
+ // PacketOut indicates that the packet should be passed to the link
+ // endpoint.
+ PacketOut PacketLooping = 1 << iota
+
+ // PacketLoop indicates that the packet should be handled locally.
+ PacketLoop
+)
+
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
@@ -149,7 +161,7 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol.
- WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error
+ WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 2b4185014..c9603ad5e 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -46,17 +46,20 @@ type Route struct {
// ref a reference to the network endpoint through which the route
// starts.
ref *referencedNetworkEndpoint
+
+ multicastLoop bool
}
// makeRoute initializes a new route. It takes ownership of the provided
// reference to a network endpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, multicastLoop bool) Route {
return Route{
NetProto: netProto,
LocalAddress: localAddr,
LocalLinkAddress: localLinkAddr,
RemoteAddress: remoteAddr,
ref: ref,
+ multicastLoop: multicastLoop,
}
}
@@ -134,7 +137,12 @@ func (r *Route) IsResolutionRequired() bool {
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
- err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl)
+ loop := PacketOut
+ if r.multicastLoop && (header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress)) {
+ loop |= PacketLoop
+ }
+
+ err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl, loop)
if err == tcpip.ErrNoRoute {
r.Stats().IP.OutgoingPacketErrors.Increment()
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index cfda7ec3c..047b704e0 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -513,7 +513,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
+func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error {
ep := FindLinkEndpoint(linkEP)
if ep == nil {
return tcpip.ErrBadLinkEndpoint
@@ -527,7 +527,7 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint
return tcpip.ErrDuplicateNICID
}
- n := newNIC(s, id, name, ep)
+ n := newNIC(s, id, name, ep, loopback)
s.nics[id] = n
if enabled {
@@ -539,26 +539,32 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint
// CreateNIC creates a NIC with the provided id and link-layer endpoint.
func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, true)
+ return s.createNIC(id, "", linkEP, true, false)
}
// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
// and a human-readable name.
func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, true)
+ return s.createNIC(id, name, linkEP, true, false)
+}
+
+// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
+// endpoint, and a human-readable name.
+func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, true, true)
}
// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
// but leave it disable. Stack.EnableNIC must be called before the link-layer
// endpoint starts delivering packets to it.
func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, false)
+ return s.createNIC(id, "", linkEP, false, false)
}
// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
// CreateDisabledNIC.
func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, false)
+ return s.createNIC(id, name, linkEP, false, false)
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -748,7 +754,7 @@ func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.Netwo
// FindRoute creates a route to the given destination address, leaving through
// the given nic and local address (if provided).
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -758,7 +764,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok {
if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
- return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref), nil
+ return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, multicastLoop && !nic.loopback), nil
}
}
} else {
@@ -774,7 +780,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
remoteAddr = ref.ep.ID().LocalAddress
}
- r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref)
+ r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, multicastLoop && !nic.loopback)
if needRoute {
r.NextHop = route.Gateway
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index aba1e984c..b366de21d 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -112,7 +112,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.linkEP.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -122,6 +122,18 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable
b[0] = r.RemoteAddress[0]
b[1] = f.id.LocalAddress[0]
b[2] = byte(protocol)
+
+ if loop&stack.PacketLoop != 0 {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ f.HandlePacket(r, vv)
+ }
+ if loop&stack.PacketOut == 0 {
+ return nil
+ }
+
return f.linkEP.WritePacket(r, hdr, payload, fakeNetNumber)
}
@@ -262,7 +274,7 @@ func TestNetworkReceive(t *testing.T) {
}
func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) {
- r, err := s.FindRoute(0, "", addr, fakeNetNumber)
+ r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
@@ -354,7 +366,7 @@ func TestNetworkSendMultiRoute(t *testing.T) {
}
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
- r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber)
+ r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
@@ -371,7 +383,7 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr,
}
func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
- _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber)
+ _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != tcpip.ErrNoRoute {
t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err)
}
@@ -514,7 +526,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) {
}
// Get a route, check that packet is still deliverable.
- r, err := s.FindRoute(0, "", "\x02", fakeNetNumber)
+ r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
@@ -584,7 +596,7 @@ func TestPromiscuousMode(t *testing.T) {
}
// Check that we can't get a route as there is no local address.
- _, err := s.FindRoute(0, "", "\x02", fakeNetNumber)
+ _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
if err != tcpip.ErrNoRoute {
t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err)
}
@@ -622,7 +634,7 @@ func TestAddressSpoofing(t *testing.T) {
// With address spoofing disabled, FindRoute does not permit an address
// that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
+ r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err == nil {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
@@ -632,7 +644,7 @@ func TestAddressSpoofing(t *testing.T) {
if err := s.SetSpoofing(1, true); err != nil {
t.Fatalf("SetSpoofing failed: %v", err)
}
- r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
+ r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
@@ -654,14 +666,14 @@ func TestBroadcastNeedsNoRoute(t *testing.T) {
s.SetRouteTable([]tcpip.Route{})
// If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber); err != tcpip.ErrNetworkUnreachable {
+ if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil {
t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, header.IPv4Any, err)
}
- r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber)
+ r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -675,7 +687,7 @@ func TestBroadcastNeedsNoRoute(t *testing.T) {
}
// If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber); err != tcpip.ErrNetworkUnreachable {
+ if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
}
@@ -738,7 +750,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
// If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber); err != want {
+ if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
}
@@ -746,7 +758,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
}
- if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber); tc.routeNeeded {
+ if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
// Route table is empty but we need a route, this should cause an error.
if err != tcpip.ErrNoRoute {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute)
@@ -763,7 +775,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
}
// If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber); err != want {
+ if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
}
})
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index a9e844e3d..279ab3c56 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -103,7 +103,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
// Find the route.
- r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber)
+ r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
return tcpip.ErrNoRoute
}