From 56a61282953b46c8f8b707d5948a2d3958dced0c Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Fri, 8 Mar 2019 15:48:16 -0800 Subject: Implement IP_MULTICAST_LOOP. IP_MULTICAST_LOOP controls whether or not multicast packets sent on the default route are looped back. In order to implement this switch, support for sending and looping back multicast packets on the default route had to be implemented. For now we only support IPv4 multicast. PiperOrigin-RevId: 237534603 Change-Id: I490ac7ff8e8ebef417c7eb049a919c29d156ac1c --- pkg/tcpip/network/arp/arp.go | 2 +- pkg/tcpip/network/ip_test.go | 8 ++++---- pkg/tcpip/network/ipv4/ipv4.go | 15 +++++++++++++-- pkg/tcpip/network/ipv6/icmp_test.go | 2 +- pkg/tcpip/network/ipv6/ipv6.go | 15 +++++++++++++-- 5 files changed, 32 insertions(+), 10 deletions(-) (limited to 'pkg/tcpip/network') diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ed39640c1..5ab542f2c 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error { return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 97a43aece..7eb0e697d 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -177,7 +177,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv4.ProtocolNumber) + return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { @@ -191,7 +191,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv6.ProtocolNumber) + return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } func TestIPv4Send(t *testing.T) { @@ -221,7 +221,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil { + if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -450,7 +450,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil { + if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { t.Fatalf("WritePacket failed: %v", err) } } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index bfc3c08fa..545684032 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -104,7 +104,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payload.Size()) id := uint32(0) @@ -123,8 +123,19 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - r.Stats().IP.PacketsSent.Increment() + if loop&stack.PacketLoop != 0 { + views := make([]buffer.View, 1, 1+len(payload.Views())) + views[0] = hdr.View() + views = append(views, payload.Views()...) + vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + e.HandlePacket(r, vv) + } + if loop&stack.PacketOut == 0 { + return nil + } + + r.Stats().IP.PacketsSent.Increment() return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 797176243..15574bab1 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -161,7 +161,7 @@ func (c *testContext) cleanup() { func TestLinkResolution(t *testing.T) { c := newTestContext(t) defer c.cleanup() - r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber) + r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { t.Fatal(err) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 5f68ef7d5..df3b64c98 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -84,7 +84,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { length := uint16(hdr.UsedLength() + payload.Size()) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -94,8 +94,19 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - r.Stats().IP.PacketsSent.Increment() + if loop&stack.PacketLoop != 0 { + views := make([]buffer.View, 1, 1+len(payload.Views())) + views[0] = hdr.View() + views = append(views, payload.Views()...) + vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + e.HandlePacket(r, vv) + } + if loop&stack.PacketOut == 0 { + return nil + } + + r.Stats().IP.PacketsSent.Increment() return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) } -- cgit v1.2.3