From 8972e47a2edb01d66c2fc6373a5663b68e3da82c Mon Sep 17 00:00:00 2001
From: Chris Kuiper <ckuiper@google.com>
Date: Thu, 2 May 2019 19:39:55 -0700
Subject: Support reception of multicast data on more than one socket

This requires two changes:
1) Support for more than one socket to join a given multicast group.

2) Duplicate delivery of incoming multicast packets to all sockets listening
for it.

In addition, I tweaked the code (and added a test) to disallow duplicates
IP_ADD_MEMBERSHIP calls for the same group and NIC. This is how Linux does
it.

PiperOrigin-RevId: 246437315
Change-Id: Icad8300b4a8c3f501d9b4cd283bd3beabef88b72
---
 pkg/tcpip/transport/udp/endpoint.go       | 39 +++++++++++++++++++++----------
 pkg/tcpip/transport/udp/endpoint_state.go | 12 +++++-----
 2 files changed, 33 insertions(+), 18 deletions(-)

(limited to 'pkg/tcpip/transport/udp')

diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index db65a4e88..0ed0902b0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -458,14 +458,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
 			return tcpip.ErrUnknownDevice
 		}
 
-		if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
-			return err
-		}
+		memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
 
 		e.mu.Lock()
 		defer e.mu.Unlock()
 
-		e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr})
+		for _, mem := range e.multicastMemberships {
+			if mem == memToInsert {
+				return tcpip.ErrPortInUse
+			}
+		}
+
+		if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+			return err
+		}
+
+		e.multicastMemberships = append(e.multicastMemberships, memToInsert)
 
 	case tcpip.RemoveMembershipOption:
 		if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
@@ -488,21 +496,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
 			return tcpip.ErrUnknownDevice
 		}
 
-		if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
-			return err
-		}
+		memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+		memToRemoveIndex := -1
 
 		e.mu.Lock()
 		defer e.mu.Unlock()
+
 		for i, mem := range e.multicastMemberships {
-			if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr {
-				// Only remove the first match, so that each added membership above is
-				// paired with exactly 1 removal.
-				e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1]
-				e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+			if mem == memToRemove {
+				memToRemoveIndex = i
 				break
 			}
 		}
+		if memToRemoveIndex == -1 {
+			return tcpip.ErrBadLocalAddress
+		}
+
+		if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+			return err
+		}
+
+		e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
+		e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
 
 	case tcpip.MulticastLoopOption:
 		e.mu.Lock()
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 163dcbc13..74e8e9fd5 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -66,6 +66,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
 func (e *endpoint) afterLoad() {
 	e.stack = stack.StackFromEnv
 
+	for _, m := range e.multicastMemberships {
+		if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+			panic(err)
+		}
+	}
+
 	if e.state != stateBound && e.state != stateConnected {
 		return
 	}
@@ -103,10 +109,4 @@ func (e *endpoint) afterLoad() {
 	if err != nil {
 		panic(*err)
 	}
-
-	for _, m := range e.multicastMemberships {
-		if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
-			panic(err)
-		}
-	}
 }
-- 
cgit v1.2.3