summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/network/arp/arp.go2
-rw-r--r--pkg/tcpip/network/ip_test.go8
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go15
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go2
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go15
-rw-r--r--pkg/tcpip/stack/nic.go18
-rw-r--r--pkg/tcpip/stack/registration.go14
-rw-r--r--pkg/tcpip/stack/route.go12
-rw-r--r--pkg/tcpip/stack/stack.go24
-rw-r--r--pkg/tcpip/stack/stack_test.go40
-rw-r--r--pkg/tcpip/stack/transport_test.go2
-rw-r--r--pkg/tcpip/tcpip.go5
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go1
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go43
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go2
19 files changed, 155 insertions, 57 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index ed39640c1..5ab542f2c 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -79,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error {
return tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 97a43aece..7eb0e697d 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -177,7 +177,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
NIC: 1,
}})
- return s.FindRoute(1, local, remote, ipv4.ProtocolNumber)
+ return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
@@ -191,7 +191,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
NIC: 1,
}})
- return s.FindRoute(1, local, remote, ipv6.ProtocolNumber)
+ return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
}
func TestIPv4Send(t *testing.T) {
@@ -221,7 +221,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil {
+ if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -450,7 +450,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil {
+ if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index bfc3c08fa..545684032 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -104,7 +104,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
length := uint16(hdr.UsedLength() + payload.Size())
id := uint32(0)
@@ -123,8 +123,19 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
- r.Stats().IP.PacketsSent.Increment()
+ if loop&stack.PacketLoop != 0 {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ e.HandlePacket(r, vv)
+ }
+ if loop&stack.PacketOut == 0 {
+ return nil
+ }
+
+ r.Stats().IP.PacketsSent.Increment()
return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 797176243..15574bab1 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -161,7 +161,7 @@ func (c *testContext) cleanup() {
func TestLinkResolution(t *testing.T) {
c := newTestContext(t)
defer c.cleanup()
- r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber)
+ r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
t.Fatal(err)
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 5f68ef7d5..df3b64c98 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -84,7 +84,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
length := uint16(hdr.UsedLength() + payload.Size())
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -94,8 +94,19 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- r.Stats().IP.PacketsSent.Increment()
+ if loop&stack.PacketLoop != 0 {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ e.HandlePacket(r, vv)
+ }
+ if loop&stack.PacketOut == 0 {
+ return nil
+ }
+
+ r.Stats().IP.PacketsSent.Increment()
return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 79f845225..14267bb48 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -28,10 +28,11 @@ import (
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
- stack *Stack
- id tcpip.NICID
- name string
- linkEP LinkEndpoint
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+ loopback bool
demux *transportDemuxer
@@ -62,12 +63,13 @@ const (
NeverPrimaryEndpoint
)
-func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
return &NIC{
stack: stack,
id: id,
name: name,
linkEP: ep,
+ loopback: loopback,
demux: newTransportDemuxer(stack),
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
@@ -407,7 +409,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n.mu.RLock()
for _, ref := range n.endpoints {
if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
+ r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* multicastLoop */)
r.RemoteLinkAddress = remote
ref.ep.HandlePacket(&r, vv)
ref.decRef()
@@ -418,7 +420,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
}
if ref := n.getRef(protocol, dst); ref != nil {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
+ r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* multicastLoop */)
r.RemoteLinkAddress = remote
ref.ep.HandlePacket(&r, vv)
ref.decRef()
@@ -430,7 +432,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
//
// TODO: Should we be forwarding the packet even if promiscuous?
if n.stack.Forwarding() {
- r, err := n.stack.FindRoute(0, "", dst, protocol)
+ r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
n.stack.stats.IP.InvalidAddressesReceived.Increment()
return
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 62acd5919..cf4d52fe9 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -125,6 +125,18 @@ type TransportDispatcher interface {
DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
}
+// PacketLooping specifies where an outbound packet should be sent.
+type PacketLooping byte
+
+const (
+ // PacketOut indicates that the packet should be passed to the link
+ // endpoint.
+ PacketOut PacketLooping = 1 << iota
+
+ // PacketLoop indicates that the packet should be handled locally.
+ PacketLoop
+)
+
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
@@ -149,7 +161,7 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol.
- WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error
+ WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 2b4185014..c9603ad5e 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -46,17 +46,20 @@ type Route struct {
// ref a reference to the network endpoint through which the route
// starts.
ref *referencedNetworkEndpoint
+
+ multicastLoop bool
}
// makeRoute initializes a new route. It takes ownership of the provided
// reference to a network endpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, multicastLoop bool) Route {
return Route{
NetProto: netProto,
LocalAddress: localAddr,
LocalLinkAddress: localLinkAddr,
RemoteAddress: remoteAddr,
ref: ref,
+ multicastLoop: multicastLoop,
}
}
@@ -134,7 +137,12 @@ func (r *Route) IsResolutionRequired() bool {
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
- err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl)
+ loop := PacketOut
+ if r.multicastLoop && (header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress)) {
+ loop |= PacketLoop
+ }
+
+ err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl, loop)
if err == tcpip.ErrNoRoute {
r.Stats().IP.OutgoingPacketErrors.Increment()
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index cfda7ec3c..047b704e0 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -513,7 +513,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
+func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error {
ep := FindLinkEndpoint(linkEP)
if ep == nil {
return tcpip.ErrBadLinkEndpoint
@@ -527,7 +527,7 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint
return tcpip.ErrDuplicateNICID
}
- n := newNIC(s, id, name, ep)
+ n := newNIC(s, id, name, ep, loopback)
s.nics[id] = n
if enabled {
@@ -539,26 +539,32 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint
// CreateNIC creates a NIC with the provided id and link-layer endpoint.
func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, true)
+ return s.createNIC(id, "", linkEP, true, false)
}
// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
// and a human-readable name.
func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, true)
+ return s.createNIC(id, name, linkEP, true, false)
+}
+
+// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
+// endpoint, and a human-readable name.
+func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, true, true)
}
// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
// but leave it disable. Stack.EnableNIC must be called before the link-layer
// endpoint starts delivering packets to it.
func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, false)
+ return s.createNIC(id, "", linkEP, false, false)
}
// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
// CreateDisabledNIC.
func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, false)
+ return s.createNIC(id, name, linkEP, false, false)
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -748,7 +754,7 @@ func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.Netwo
// FindRoute creates a route to the given destination address, leaving through
// the given nic and local address (if provided).
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -758,7 +764,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok {
if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
- return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref), nil
+ return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, multicastLoop && !nic.loopback), nil
}
}
} else {
@@ -774,7 +780,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
remoteAddr = ref.ep.ID().LocalAddress
}
- r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref)
+ r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, multicastLoop && !nic.loopback)
if needRoute {
r.NextHop = route.Gateway
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index aba1e984c..b366de21d 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -112,7 +112,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.linkEP.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -122,6 +122,18 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable
b[0] = r.RemoteAddress[0]
b[1] = f.id.LocalAddress[0]
b[2] = byte(protocol)
+
+ if loop&stack.PacketLoop != 0 {
+ views := make([]buffer.View, 1, 1+len(payload.Views()))
+ views[0] = hdr.View()
+ views = append(views, payload.Views()...)
+ vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ f.HandlePacket(r, vv)
+ }
+ if loop&stack.PacketOut == 0 {
+ return nil
+ }
+
return f.linkEP.WritePacket(r, hdr, payload, fakeNetNumber)
}
@@ -262,7 +274,7 @@ func TestNetworkReceive(t *testing.T) {
}
func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) {
- r, err := s.FindRoute(0, "", addr, fakeNetNumber)
+ r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
@@ -354,7 +366,7 @@ func TestNetworkSendMultiRoute(t *testing.T) {
}
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
- r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber)
+ r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
@@ -371,7 +383,7 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr,
}
func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
- _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber)
+ _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != tcpip.ErrNoRoute {
t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err)
}
@@ -514,7 +526,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) {
}
// Get a route, check that packet is still deliverable.
- r, err := s.FindRoute(0, "", "\x02", fakeNetNumber)
+ r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
@@ -584,7 +596,7 @@ func TestPromiscuousMode(t *testing.T) {
}
// Check that we can't get a route as there is no local address.
- _, err := s.FindRoute(0, "", "\x02", fakeNetNumber)
+ _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
if err != tcpip.ErrNoRoute {
t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err)
}
@@ -622,7 +634,7 @@ func TestAddressSpoofing(t *testing.T) {
// With address spoofing disabled, FindRoute does not permit an address
// that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
+ r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err == nil {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
@@ -632,7 +644,7 @@ func TestAddressSpoofing(t *testing.T) {
if err := s.SetSpoofing(1, true); err != nil {
t.Fatalf("SetSpoofing failed: %v", err)
}
- r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
+ r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
@@ -654,14 +666,14 @@ func TestBroadcastNeedsNoRoute(t *testing.T) {
s.SetRouteTable([]tcpip.Route{})
// If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber); err != tcpip.ErrNetworkUnreachable {
+ if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil {
t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, header.IPv4Any, err)
}
- r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber)
+ r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -675,7 +687,7 @@ func TestBroadcastNeedsNoRoute(t *testing.T) {
}
// If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber); err != tcpip.ErrNetworkUnreachable {
+ if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
}
@@ -738,7 +750,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
// If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber); err != want {
+ if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
}
@@ -746,7 +758,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
}
- if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber); tc.routeNeeded {
+ if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
// Route table is empty but we need a route, this should cause an error.
if err != tcpip.ErrNoRoute {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute)
@@ -763,7 +775,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
}
// If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber); err != want {
+ if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
}
})
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index a9e844e3d..279ab3c56 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -103,7 +103,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
// Find the route.
- r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber)
+ r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
return tcpip.ErrNoRoute
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 7010d1b68..825854148 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -68,6 +68,7 @@ func (e *Error) IgnoreStats() bool {
var (
ErrUnknownProtocol = &Error{msg: "unknown protocol"}
ErrUnknownNICID = &Error{msg: "unknown nic id"}
+ ErrUnknownDevice = &Error{msg: "unknown device"}
ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
ErrDuplicateAddress = &Error{msg: "duplicate address"}
@@ -477,6 +478,10 @@ type MulticastInterfaceOption struct {
InterfaceAddr Address
}
+// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether
+// multicast packets sent over a non-loopback interface will be looped back.
+type MulticastLoopOption bool
+
// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
// AddMembershipOption and RemoveMembershipOption.
type MembershipOption struct {
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 05c4b532a..d876005fe 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -277,7 +277,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
}
// Find the enpoint.
- r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto)
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */)
if err != nil {
return 0, nil, err
}
@@ -471,7 +471,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto)
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index 21008d089..8a7909246 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -71,7 +71,7 @@ func (e *endpoint) afterLoad() {
var err *tcpip.Error
if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto)
+ e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
if err != nil {
panic(*err)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ae99f0f8e..fc4f82402 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1091,7 +1091,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto)
+ r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 87e988afa..a42e09b8c 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -307,6 +307,7 @@ func loadError(s string) *tcpip.Error {
var errors = []*tcpip.Error{
tcpip.ErrUnknownProtocol,
tcpip.ErrUnknownNICID,
+ tcpip.ErrUnknownDevice,
tcpip.ErrUnknownProtocolOption,
tcpip.ErrDuplicateNICID,
tcpip.ErrDuplicateAddress,
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 8ccb79c48..d271490c1 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -27,6 +27,7 @@ go_library(
imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/log",
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4108cb09c..3693abae5 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -81,6 +81,7 @@ type endpoint struct {
multicastTTL uint8
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
+ multicastLoop bool
reusePort bool
broadcast bool
@@ -124,6 +125,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
//
// Linux defaults to TTL=1.
multicastTTL: 1,
+ multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
@@ -274,7 +276,7 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto)
+ r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
if err != nil {
return stack.Route{}, 0, 0, err
}
@@ -458,13 +460,19 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.AddMembershipOption:
nicID := v.NIC
- if v.InterfaceAddr != header.IPv4Any {
+ if v.InterfaceAddr == header.IPv4Any {
+ if nicID == 0 {
+ r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
}
if nicID == 0 {
- // TODO: Allow adding memberships without
- // specifing an interface.
- return tcpip.ErrNoRoute
+ return tcpip.ErrUnknownDevice
}
// TODO: check that v.MulticastAddr is a multicast address.
@@ -479,11 +487,19 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.RemoveMembershipOption:
nicID := v.NIC
- if v.InterfaceAddr != header.IPv4Any {
+ if v.InterfaceAddr == header.IPv4Any {
+ if nicID == 0 {
+ r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
}
if nicID == 0 {
- return tcpip.ErrNoRoute
+ return tcpip.ErrUnknownDevice
}
// TODO: check that v.MulticastAddr is a multicast address.
@@ -503,6 +519,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
}
+ case tcpip.MulticastLoopOption:
+ e.mu.Lock()
+ e.multicastLoop = bool(v)
+ e.mu.Unlock()
+
case tcpip.ReusePortOption:
e.mu.Lock()
e.reusePort = v != 0
@@ -578,6 +599,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case *tcpip.MulticastLoopOption:
+ e.mu.RLock()
+ v := e.multicastLoop
+ e.mu.RUnlock()
+
+ *o = tcpip.MulticastLoopOption(v)
+ return nil
+
case *tcpip.ReusePortOption:
e.mu.RLock()
v := e.reusePort
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 4d8210294..b2daaf751 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -82,7 +82,7 @@ func (e *endpoint) afterLoad() {
var err *tcpip.Error
if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto)
+ e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
if err != nil {
panic(*err)
}