diff options
Diffstat (limited to 'pkg/tcpip/network/ip')
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol.go | 145 |
1 files changed, 78 insertions, 67 deletions
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index e308550c4..bc7f7f637 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -138,76 +138,89 @@ type MulticastGroupProtocol interface { // IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state // machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. // +// Callers must synchronize accesses to the generic multicast protocol state; +// GenericMulticastProtocolState obtains no locks in any of its methods. The +// only exception to this is GenericMulticastProtocolState's timer/job callbacks +// which will obtain the lock provided to the GenericMulticastProtocolState when +// it is initialized. +// // GenericMulticastProtocolState.Init MUST be called before calling any of // the methods on GenericMulticastProtocolState. type GenericMulticastProtocolState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + opts GenericMulticastProtocolOptions - mu struct { - sync.RWMutex + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState - // memberships holds group addresses and their associated state. - memberships map[tcpip.Address]multicastGroupState - } + // protocolMU is the mutex used to protect the protocol. + protocolMU *sync.RWMutex } // Init initializes the Generic Multicast Protocol state. -func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) { - g.mu.Lock() - defer g.mu.Unlock() +// +// Must only be called once for the lifetime of g. +// +// The GenericMulticastProtocolState will only grab the lock when timers/jobs +// fire. +func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { + if g.memberships != nil { + panic("attempted to initialize generic membership protocol state twice") + } + g.opts = opts - g.mu.memberships = make(map[tcpip.Address]multicastGroupState) + g.memberships = make(map[tcpip.Address]multicastGroupState) + g.protocolMU = protocolMU } -// MakeAllNonMember transitions all groups to the non-member state. +// MakeAllNonMemberLocked transitions all groups to the non-member state. // // The groups will still be considered joined locally. -func (g *GenericMulticastProtocolState) MakeAllNonMember() { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.transitionToNonMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// InitializeGroups initializes each group, as if they were newly joined but -// without affecting the groups' join count. +// InitializeGroupsLocked initializes each group, as if they were newly joined +// but without affecting the groups' join count. // // Must only be called after calling MakeAllNonMember as a group should not be // initialized while it is not in the non-member state. -func (g *GenericMulticastProtocolState) InitializeGroups() { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.initializeNewMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// JoinGroup handles joining a new group. +// JoinGroupLocked handles joining a new group. // // If dontInitialize is true, the group will be not be initialized and will be // left in the non-member state - no packets will be sent for it until it is // initialized via InitializeGroups. -func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) { - g.mu.Lock() - defer g.mu.Unlock() - - if info, ok := g.mu.memberships[groupAddress]; ok { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address, dontInitialize bool) { + if info, ok := g.memberships[groupAddress]; ok { // The group has already been joined. info.joins++ - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return } @@ -217,15 +230,15 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do // The state will be updated below, if required. state: nonMember, lastToSendReport: false, - delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() { - info, ok := g.mu.memberships[groupAddress] + delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + info, ok := g.memberships[groupAddress] if !ok { panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) } info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil info.state = idleMember - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info }), } @@ -233,25 +246,24 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do g.initializeNewMemberLocked(groupAddress, &info) } - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } -// IsLocallyJoined returns true if the group is locally joined. -func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool { - g.mu.RLock() - defer g.mu.RUnlock() - _, ok := g.mu.memberships[groupAddress] +// IsLocallyJoinedRLocked returns true if the group is locally joined. +// +// Precondition: g.protocolMU must be read locked. +func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { + _, ok := g.memberships[groupAddress] return ok } -// LeaveGroup handles leaving the group. +// LeaveGroupLocked handles leaving the group. // // Returns false if the group is not currently joined. -func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool { - g.mu.Lock() - defer g.mu.Unlock() - - info, ok := g.mu.memberships[groupAddress] +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { + info, ok := g.memberships[groupAddress] if !ok { return false } @@ -262,30 +274,30 @@ func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) b info.joins-- if info.joins != 0 { // If we still have outstanding joins, then do nothing further. - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return true } g.transitionToNonMemberLocked(groupAddress, &info) - delete(g.mu.memberships, groupAddress) + delete(g.memberships, groupAddress) return true } -// HandleQuery handles a query message with the specified maximum response time. +// HandleQueryLocked handles a query message with the specified maximum response +// time. // // If the group address is unspecified, then reports will be scheduled for all // joined groups. // // Report(s) will be scheduled to be sent after a random duration between 0 and // the maximum response time. -func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 2.4 (for IGMPv2), // // In a Membership Query message, the group address field is set to zero @@ -299,28 +311,27 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, // when sending a Multicast-Address-Specific Query. if groupAddress.Unspecified() { // This is a general query as the group address is unspecified. - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } - } else if info, ok := g.mu.memberships[groupAddress]; ok { + } else if info, ok := g.memberships[groupAddress]; ok { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// HandleReport handles a report message. +// HandleReportLocked handles a report message. // // If the report is for a joined group, any active delayed report will be // cancelled and the host state for the group transitions to idle. -func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), // // If the host receives another host's Report (version 1 or 2) while it has @@ -333,17 +344,17 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) // multicast address while it has a timer running for that same address // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. - if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember { + if info, ok := g.memberships[groupAddress]; ok && info.state == delayingMember { info.delayedReportJob.Cancel() info.lastToSendReport = false info.state = idleMember - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } // initializeNewMemberLocked initializes a new group membership. // -// Precondition: g.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state != nonMember { panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state)) @@ -465,7 +476,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres // transitionToNonMemberLocked transitions the given multicast group the the // non-member/listener state. // -// Precondition: e.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state == nonMember { return @@ -479,7 +490,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress // setDelayTimerForAddressRLocked sets timer to send a delay report. // -// Precondition: g.mu MUST be read locked. +// Precondition: g.protocolMU MUST be read locked. func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { if info.state == nonMember { return |