diff options
author | Tamir Duberstein <tamird@google.com> | 2020-09-14 15:19:50 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-09-14 15:22:00 -0700 |
commit | 2747030ec7453219c09b34f773b7d2ba6a7fc552 (patch) | |
tree | c2d60dc1d9c5136c6fd23513d2995dae8a763b46 /pkg/tcpip/transport/udp | |
parent | 05d2ebee5e4ebc31cd71f6064ca433a58692be76 (diff) |
Store multicast memberships in a set
This is simpler and more performant.
PiperOrigin-RevId: 331639978
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 39 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 2 |
2 files changed, 16 insertions, 25 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 2828b2c01..b572c39db 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -139,7 +139,7 @@ type endpoint struct { // multicastMemberships that need to be remvoed when the endpoint is // closed. Protected by the mu mutex. - multicastMemberships []multicastMembership + multicastMemberships map[multicastMembership]struct{} // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -182,12 +182,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // TTL=1. // // Linux defaults to TTL=1. - multicastTTL: 1, - multicastLoop: true, - rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, - state: StateInitial, - uniqueID: s.UniqueID(), + multicastTTL: 1, + multicastLoop: true, + rcvBufSizeMax: 32 * 1024, + sndBufSizeMax: 32 * 1024, + multicastMemberships: make(map[multicastMembership]struct{}), + state: StateInitial, + uniqueID: s.UniqueID(), } // Override with stack defaults. @@ -237,10 +238,10 @@ func (e *endpoint) Close() { e.boundPortFlags = ports.Flags{} } - for _, mem := range e.multicastMemberships { + for mem := range e.multicastMemberships { e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } - e.multicastMemberships = nil + e.multicastMemberships = make(map[multicastMembership]struct{}) // Close the receive list and drain it. e.rcvMu.Lock() @@ -752,17 +753,15 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - for _, mem := range e.multicastMemberships { - if mem == memToInsert { - return tcpip.ErrPortInUse - } + if _, ok := e.multicastMemberships[memToInsert]; ok { + return tcpip.ErrPortInUse } if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } - e.multicastMemberships = append(e.multicastMemberships, memToInsert) + e.multicastMemberships[memToInsert] = struct{}{} case *tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { @@ -786,18 +785,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - memToRemoveIndex := -1 e.mu.Lock() defer e.mu.Unlock() - for i, mem := range e.multicastMemberships { - if mem == memToRemove { - memToRemoveIndex = i - break - } - } - if memToRemoveIndex == -1 { + if _, ok := e.multicastMemberships[memToRemove]; !ok { return tcpip.ErrBadLocalAddress } @@ -805,8 +797,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { return err } - e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] - e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + delete(e.multicastMemberships, memToRemove) case *tcpip.BindToDeviceOption: id := tcpip.NICID(*v) diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 851e6b635..858c99a45 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - for _, m := range e.multicastMemberships { + for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { panic(err) } |