summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
authorChris Kuiper <ckuiper@google.com>2019-05-02 19:39:55 -0700
committerShentubot <shentubot@google.com>2019-05-02 19:41:00 -0700
commit8972e47a2edb01d66c2fc6373a5663b68e3da82c (patch)
tree28195396bd4e663cfd2886ffbfd93f8d66af24dd /pkg/tcpip/transport
parent5f8225c009fcf297139c54c7b329da4aff679ece (diff)
Support reception of multicast data on more than one socket
This requires two changes: 1) Support for more than one socket to join a given multicast group. 2) Duplicate delivery of incoming multicast packets to all sockets listening for it. In addition, I tweaked the code (and added a test) to disallow duplicates IP_ADD_MEMBERSHIP calls for the same group and NIC. This is how Linux does it. PiperOrigin-RevId: 246437315 Change-Id: Icad8300b4a8c3f501d9b4cd283bd3beabef88b72
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go39
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go12
2 files changed, 33 insertions, 18 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index db65a4e88..0ed0902b0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -458,14 +458,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownDevice
}
- if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
e.mu.Lock()
defer e.mu.Unlock()
- e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr})
+ for _, mem := range e.multicastMemberships {
+ if mem == memToInsert {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships = append(e.multicastMemberships, memToInsert)
case tcpip.RemoveMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
@@ -488,21 +496,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownDevice
}
- if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+ memToRemoveIndex := -1
e.mu.Lock()
defer e.mu.Unlock()
+
for i, mem := range e.multicastMemberships {
- if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr {
- // Only remove the first match, so that each added membership above is
- // paired with exactly 1 removal.
- e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1]
- e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ if mem == memToRemove {
+ memToRemoveIndex = i
break
}
}
+ if memToRemoveIndex == -1 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
case tcpip.MulticastLoopOption:
e.mu.Lock()
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 163dcbc13..74e8e9fd5 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -66,6 +66,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
func (e *endpoint) afterLoad() {
e.stack = stack.StackFromEnv
+ for _, m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(err)
+ }
+ }
+
if e.state != stateBound && e.state != stateConnected {
return
}
@@ -103,10 +109,4 @@ func (e *endpoint) afterLoad() {
if err != nil {
panic(*err)
}
-
- for _, m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
- panic(err)
- }
- }
}