diff options
author | Ian Gudger <igudger@google.com> | 2019-03-08 15:48:16 -0800 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-03-08 15:49:17 -0800 |
commit | 56a61282953b46c8f8b707d5948a2d3958dced0c (patch) | |
tree | 2336f9f92e227d5f43bbad81cee80c527573a6a2 /pkg/tcpip | |
parent | 832589cb076a638ca53076ebb66afb9fac4597d1 (diff) |
Implement IP_MULTICAST_LOOP.
IP_MULTICAST_LOOP controls whether or not multicast packets sent on the default
route are looped back. In order to implement this switch, support for sending
and looping back multicast packets on the default route had to be implemented.
For now we only support IPv4 multicast.
PiperOrigin-RevId: 237534603
Change-Id: I490ac7ff8e8ebef417c7eb049a919c29d156ac1c
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint_state.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 43 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 2 |
19 files changed, 155 insertions, 57 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ed39640c1..5ab542f2c 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error { return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 97a43aece..7eb0e697d 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -177,7 +177,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv4.ProtocolNumber) + return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { @@ -191,7 +191,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv6.ProtocolNumber) + return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } func TestIPv4Send(t *testing.T) { @@ -221,7 +221,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil { + if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -450,7 +450,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil { + if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { t.Fatalf("WritePacket failed: %v", err) } } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index bfc3c08fa..545684032 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -104,7 +104,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payload.Size()) id := uint32(0) @@ -123,8 +123,19 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - r.Stats().IP.PacketsSent.Increment() + if loop&stack.PacketLoop != 0 { + views := make([]buffer.View, 1, 1+len(payload.Views())) + views[0] = hdr.View() + views = append(views, payload.Views()...) + vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + e.HandlePacket(r, vv) + } + if loop&stack.PacketOut == 0 { + return nil + } + + r.Stats().IP.PacketsSent.Increment() return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 797176243..15574bab1 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -161,7 +161,7 @@ func (c *testContext) cleanup() { func TestLinkResolution(t *testing.T) { c := newTestContext(t) defer c.cleanup() - r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber) + r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { t.Fatal(err) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 5f68ef7d5..df3b64c98 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -84,7 +84,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { length := uint16(hdr.UsedLength() + payload.Size()) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -94,8 +94,19 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - r.Stats().IP.PacketsSent.Increment() + if loop&stack.PacketLoop != 0 { + views := make([]buffer.View, 1, 1+len(payload.Views())) + views[0] = hdr.View() + views = append(views, payload.Views()...) + vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + e.HandlePacket(r, vv) + } + if loop&stack.PacketOut == 0 { + return nil + } + + r.Stats().IP.PacketsSent.Increment() return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 79f845225..14267bb48 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -28,10 +28,11 @@ import ( // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - stack *Stack - id tcpip.NICID - name string - linkEP LinkEndpoint + stack *Stack + id tcpip.NICID + name string + linkEP LinkEndpoint + loopback bool demux *transportDemuxer @@ -62,12 +63,13 @@ const ( NeverPrimaryEndpoint ) -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { return &NIC{ stack: stack, id: id, name: name, linkEP: ep, + loopback: loopback, demux: newTransportDemuxer(stack), primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), @@ -407,7 +409,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.mu.RLock() for _, ref := range n.endpoints { if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref) + r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* multicastLoop */) r.RemoteLinkAddress = remote ref.ep.HandlePacket(&r, vv) ref.decRef() @@ -418,7 +420,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr } if ref := n.getRef(protocol, dst); ref != nil { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref) + r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* multicastLoop */) r.RemoteLinkAddress = remote ref.ep.HandlePacket(&r, vv) ref.decRef() @@ -430,7 +432,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr // // TODO: Should we be forwarding the packet even if promiscuous? if n.stack.Forwarding() { - r, err := n.stack.FindRoute(0, "", dst, protocol) + r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { n.stack.stats.IP.InvalidAddressesReceived.Increment() return diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 62acd5919..cf4d52fe9 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -125,6 +125,18 @@ type TransportDispatcher interface { DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) } +// PacketLooping specifies where an outbound packet should be sent. +type PacketLooping byte + +const ( + // PacketOut indicates that the packet should be passed to the link + // endpoint. + PacketOut PacketLooping = 1 << iota + + // PacketLoop indicates that the packet should be handled locally. + PacketLoop +) + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { @@ -149,7 +161,7 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. - WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error + WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 2b4185014..c9603ad5e 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -46,17 +46,20 @@ type Route struct { // ref a reference to the network endpoint through which the route // starts. ref *referencedNetworkEndpoint + + multicastLoop bool } // makeRoute initializes a new route. It takes ownership of the provided // reference to a network endpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, multicastLoop bool) Route { return Route{ NetProto: netProto, LocalAddress: localAddr, LocalLinkAddress: localLinkAddr, RemoteAddress: remoteAddr, ref: ref, + multicastLoop: multicastLoop, } } @@ -134,7 +137,12 @@ func (r *Route) IsResolutionRequired() bool { // WritePacket writes the packet through the given route. func (r *Route) WritePacket(hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { - err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl) + loop := PacketOut + if r.multicastLoop && (header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress)) { + loop |= PacketLoop + } + + err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl, loop) if err == tcpip.ErrNoRoute { r.Stats().IP.OutgoingPacketErrors.Increment() } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index cfda7ec3c..047b704e0 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -513,7 +513,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error { +func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error { ep := FindLinkEndpoint(linkEP) if ep == nil { return tcpip.ErrBadLinkEndpoint @@ -527,7 +527,7 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint return tcpip.ErrDuplicateNICID } - n := newNIC(s, id, name, ep) + n := newNIC(s, id, name, ep, loopback) s.nics[id] = n if enabled { @@ -539,26 +539,32 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint // CreateNIC creates a NIC with the provided id and link-layer endpoint. func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, true) + return s.createNIC(id, "", linkEP, true, false) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true) + return s.createNIC(id, name, linkEP, true, false) +} + +// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer +// endpoint, and a human-readable name. +func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { + return s.createNIC(id, name, linkEP, true, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, false) + return s.createNIC(id, "", linkEP, false, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, false) + return s.createNIC(id, name, linkEP, false, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -748,7 +754,7 @@ func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.Netwo // FindRoute creates a route to the given destination address, leaving through // the given nic and local address (if provided). -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) { +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -758,7 +764,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok { if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { - return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref), nil + return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, multicastLoop && !nic.loopback), nil } } } else { @@ -774,7 +780,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n remoteAddr = ref.ep.ID().LocalAddress } - r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref) + r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, multicastLoop && !nic.loopback) if needRoute { r.NextHop = route.Gateway } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index aba1e984c..b366de21d 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -112,7 +112,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.linkEP.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -122,6 +122,18 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable b[0] = r.RemoteAddress[0] b[1] = f.id.LocalAddress[0] b[2] = byte(protocol) + + if loop&stack.PacketLoop != 0 { + views := make([]buffer.View, 1, 1+len(payload.Views())) + views[0] = hdr.View() + views = append(views, payload.Views()...) + vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + f.HandlePacket(r, vv) + } + if loop&stack.PacketOut == 0 { + return nil + } + return f.linkEP.WritePacket(r, hdr, payload, fakeNetNumber) } @@ -262,7 +274,7 @@ func TestNetworkReceive(t *testing.T) { } func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) { - r, err := s.FindRoute(0, "", addr, fakeNetNumber) + r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatalf("FindRoute failed: %v", err) } @@ -354,7 +366,7 @@ func TestNetworkSendMultiRoute(t *testing.T) { } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { - r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber) + r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatalf("FindRoute failed: %v", err) } @@ -371,7 +383,7 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, } func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { - _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber) + _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != tcpip.ErrNoRoute { t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err) } @@ -514,7 +526,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { } // Get a route, check that packet is still deliverable. - r, err := s.FindRoute(0, "", "\x02", fakeNetNumber) + r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatalf("FindRoute failed: %v", err) } @@ -584,7 +596,7 @@ func TestPromiscuousMode(t *testing.T) { } // Check that we can't get a route as there is no local address. - _, err := s.FindRoute(0, "", "\x02", fakeNetNumber) + _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) if err != tcpip.ErrNoRoute { t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err) } @@ -622,7 +634,7 @@ func TestAddressSpoofing(t *testing.T) { // With address spoofing disabled, FindRoute does not permit an address // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber) + r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err == nil { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } @@ -632,7 +644,7 @@ func TestAddressSpoofing(t *testing.T) { if err := s.SetSpoofing(1, true); err != nil { t.Fatalf("SetSpoofing failed: %v", err) } - r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber) + r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatalf("FindRoute failed: %v", err) } @@ -654,14 +666,14 @@ func TestBroadcastNeedsNoRoute(t *testing.T) { s.SetRouteTable([]tcpip.Route{}) // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber); err != tcpip.ErrNetworkUnreachable { + if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, header.IPv4Any, err) } - r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber) + r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } @@ -675,7 +687,7 @@ func TestBroadcastNeedsNoRoute(t *testing.T) { } // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber); err != tcpip.ErrNetworkUnreachable { + if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } } @@ -738,7 +750,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber); err != want { + if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) } @@ -746,7 +758,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) } - if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber); tc.routeNeeded { + if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { // Route table is empty but we need a route, this should cause an error. if err != tcpip.ErrNoRoute { t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute) @@ -763,7 +775,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber); err != want { + if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want) } }) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a9e844e3d..279ab3c56 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -103,7 +103,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { f.peerAddr = addr.Addr // Find the route. - r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber) + r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) if err != nil { return tcpip.ErrNoRoute } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 7010d1b68..825854148 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -68,6 +68,7 @@ func (e *Error) IgnoreStats() bool { var ( ErrUnknownProtocol = &Error{msg: "unknown protocol"} ErrUnknownNICID = &Error{msg: "unknown nic id"} + ErrUnknownDevice = &Error{msg: "unknown device"} ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"} ErrDuplicateNICID = &Error{msg: "duplicate nic id"} ErrDuplicateAddress = &Error{msg: "duplicate address"} @@ -477,6 +478,10 @@ type MulticastInterfaceOption struct { InterfaceAddr Address } +// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether +// multicast packets sent over a non-loopback interface will be looped back. +type MulticastLoopOption bool + // MembershipOption is used by SetSockOpt/GetSockOpt as an argument to // AddMembershipOption and RemoveMembershipOption. type MembershipOption struct { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 05c4b532a..d876005fe 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -277,7 +277,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c } // Find the enpoint. - r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto) + r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -471,7 +471,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto) + r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 21008d089..8a7909246 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -71,7 +71,7 @@ func (e *endpoint) afterLoad() { var err *tcpip.Error if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto) + e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) if err != nil { panic(*err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ae99f0f8e..fc4f82402 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1091,7 +1091,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto) + r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 87e988afa..a42e09b8c 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -307,6 +307,7 @@ func loadError(s string) *tcpip.Error { var errors = []*tcpip.Error{ tcpip.ErrUnknownProtocol, tcpip.ErrUnknownNICID, + tcpip.ErrUnknownDevice, tcpip.ErrUnknownProtocolOption, tcpip.ErrDuplicateNICID, tcpip.ErrDuplicateAddress, diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 8ccb79c48..d271490c1 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -27,6 +27,7 @@ go_library( imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4108cb09c..3693abae5 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -81,6 +81,7 @@ type endpoint struct { multicastTTL uint8 multicastAddr tcpip.Address multicastNICID tcpip.NICID + multicastLoop bool reusePort bool broadcast bool @@ -124,6 +125,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite // // Linux defaults to TTL=1. multicastTTL: 1, + multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, } @@ -274,7 +276,7 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto) + r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { return stack.Route{}, 0, 0, err } @@ -458,13 +460,19 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { case tcpip.AddMembershipOption: nicID := v.NIC - if v.InterfaceAddr != header.IPv4Any { + if v.InterfaceAddr == header.IPv4Any { + if nicID == 0 { + r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err == nil { + nicID = r.NICID() + r.Release() + } + } + } else { nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) } if nicID == 0 { - // TODO: Allow adding memberships without - // specifing an interface. - return tcpip.ErrNoRoute + return tcpip.ErrUnknownDevice } // TODO: check that v.MulticastAddr is a multicast address. @@ -479,11 +487,19 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { case tcpip.RemoveMembershipOption: nicID := v.NIC - if v.InterfaceAddr != header.IPv4Any { + if v.InterfaceAddr == header.IPv4Any { + if nicID == 0 { + r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err == nil { + nicID = r.NICID() + r.Release() + } + } + } else { nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) } if nicID == 0 { - return tcpip.ErrNoRoute + return tcpip.ErrUnknownDevice } // TODO: check that v.MulticastAddr is a multicast address. @@ -503,6 +519,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } + case tcpip.MulticastLoopOption: + e.mu.Lock() + e.multicastLoop = bool(v) + e.mu.Unlock() + case tcpip.ReusePortOption: e.mu.Lock() e.reusePort = v != 0 @@ -578,6 +599,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.MulticastLoopOption: + e.mu.RLock() + v := e.multicastLoop + e.mu.RUnlock() + + *o = tcpip.MulticastLoopOption(v) + return nil + case *tcpip.ReusePortOption: e.mu.RLock() v := e.reusePort diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 4d8210294..b2daaf751 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -82,7 +82,7 @@ func (e *endpoint) afterLoad() { var err *tcpip.Error if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto) + e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) if err != nil { panic(*err) } |