diff options
Diffstat (limited to 'pkg/tcpip/transport/udp/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 39 |
1 files changed, 27 insertions, 12 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index db65a4e88..0ed0902b0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -458,14 +458,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownDevice } - if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { - return err - } + memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} e.mu.Lock() defer e.mu.Unlock() - e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr}) + for _, mem := range e.multicastMemberships { + if mem == memToInsert { + return tcpip.ErrPortInUse + } + } + + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships = append(e.multicastMemberships, memToInsert) case tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { @@ -488,21 +496,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownDevice } - if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { - return err - } + memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + memToRemoveIndex := -1 e.mu.Lock() defer e.mu.Unlock() + for i, mem := range e.multicastMemberships { - if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr { - // Only remove the first match, so that each added membership above is - // paired with exactly 1 removal. - e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1] - e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + if mem == memToRemove { + memToRemoveIndex = i break } } + if memToRemoveIndex == -1 { + return tcpip.ErrBadLocalAddress + } + + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] + e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] case tcpip.MulticastLoopOption: e.mu.Lock() |