diff options
Diffstat (limited to 'pkg/tcpip/network/ip')
-rw-r--r-- | pkg/tcpip/network/ip/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol.go | 145 | ||||
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol_test.go | 412 |
3 files changed, 364 insertions, 194 deletions
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD index 6ca200b48..ca1247c1e 100644 --- a/pkg/tcpip/network/ip/BUILD +++ b/pkg/tcpip/network/ip/BUILD @@ -18,6 +18,7 @@ go_test( srcs = ["generic_multicast_protocol_test.go"], deps = [ ":ip", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/faketime", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index e308550c4..bc7f7f637 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -138,76 +138,89 @@ type MulticastGroupProtocol interface { // IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state // machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. // +// Callers must synchronize accesses to the generic multicast protocol state; +// GenericMulticastProtocolState obtains no locks in any of its methods. The +// only exception to this is GenericMulticastProtocolState's timer/job callbacks +// which will obtain the lock provided to the GenericMulticastProtocolState when +// it is initialized. +// // GenericMulticastProtocolState.Init MUST be called before calling any of // the methods on GenericMulticastProtocolState. type GenericMulticastProtocolState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + opts GenericMulticastProtocolOptions - mu struct { - sync.RWMutex + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState - // memberships holds group addresses and their associated state. - memberships map[tcpip.Address]multicastGroupState - } + // protocolMU is the mutex used to protect the protocol. + protocolMU *sync.RWMutex } // Init initializes the Generic Multicast Protocol state. -func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) { - g.mu.Lock() - defer g.mu.Unlock() +// +// Must only be called once for the lifetime of g. +// +// The GenericMulticastProtocolState will only grab the lock when timers/jobs +// fire. +func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { + if g.memberships != nil { + panic("attempted to initialize generic membership protocol state twice") + } + g.opts = opts - g.mu.memberships = make(map[tcpip.Address]multicastGroupState) + g.memberships = make(map[tcpip.Address]multicastGroupState) + g.protocolMU = protocolMU } -// MakeAllNonMember transitions all groups to the non-member state. +// MakeAllNonMemberLocked transitions all groups to the non-member state. // // The groups will still be considered joined locally. -func (g *GenericMulticastProtocolState) MakeAllNonMember() { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.transitionToNonMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// InitializeGroups initializes each group, as if they were newly joined but -// without affecting the groups' join count. +// InitializeGroupsLocked initializes each group, as if they were newly joined +// but without affecting the groups' join count. // // Must only be called after calling MakeAllNonMember as a group should not be // initialized while it is not in the non-member state. -func (g *GenericMulticastProtocolState) InitializeGroups() { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.initializeNewMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// JoinGroup handles joining a new group. +// JoinGroupLocked handles joining a new group. // // If dontInitialize is true, the group will be not be initialized and will be // left in the non-member state - no packets will be sent for it until it is // initialized via InitializeGroups. -func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) { - g.mu.Lock() - defer g.mu.Unlock() - - if info, ok := g.mu.memberships[groupAddress]; ok { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address, dontInitialize bool) { + if info, ok := g.memberships[groupAddress]; ok { // The group has already been joined. info.joins++ - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return } @@ -217,15 +230,15 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do // The state will be updated below, if required. state: nonMember, lastToSendReport: false, - delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() { - info, ok := g.mu.memberships[groupAddress] + delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + info, ok := g.memberships[groupAddress] if !ok { panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) } info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil info.state = idleMember - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info }), } @@ -233,25 +246,24 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do g.initializeNewMemberLocked(groupAddress, &info) } - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } -// IsLocallyJoined returns true if the group is locally joined. -func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool { - g.mu.RLock() - defer g.mu.RUnlock() - _, ok := g.mu.memberships[groupAddress] +// IsLocallyJoinedRLocked returns true if the group is locally joined. +// +// Precondition: g.protocolMU must be read locked. +func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { + _, ok := g.memberships[groupAddress] return ok } -// LeaveGroup handles leaving the group. +// LeaveGroupLocked handles leaving the group. // // Returns false if the group is not currently joined. -func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool { - g.mu.Lock() - defer g.mu.Unlock() - - info, ok := g.mu.memberships[groupAddress] +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { + info, ok := g.memberships[groupAddress] if !ok { return false } @@ -262,30 +274,30 @@ func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) b info.joins-- if info.joins != 0 { // If we still have outstanding joins, then do nothing further. - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return true } g.transitionToNonMemberLocked(groupAddress, &info) - delete(g.mu.memberships, groupAddress) + delete(g.memberships, groupAddress) return true } -// HandleQuery handles a query message with the specified maximum response time. +// HandleQueryLocked handles a query message with the specified maximum response +// time. // // If the group address is unspecified, then reports will be scheduled for all // joined groups. // // Report(s) will be scheduled to be sent after a random duration between 0 and // the maximum response time. -func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 2.4 (for IGMPv2), // // In a Membership Query message, the group address field is set to zero @@ -299,28 +311,27 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, // when sending a Multicast-Address-Specific Query. if groupAddress.Unspecified() { // This is a general query as the group address is unspecified. - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } - } else if info, ok := g.mu.memberships[groupAddress]; ok { + } else if info, ok := g.memberships[groupAddress]; ok { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// HandleReport handles a report message. +// HandleReportLocked handles a report message. // // If the report is for a joined group, any active delayed report will be // cancelled and the host state for the group transitions to idle. -func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), // // If the host receives another host's Report (version 1 or 2) while it has @@ -333,17 +344,17 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) // multicast address while it has a timer running for that same address // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. - if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember { + if info, ok := g.memberships[groupAddress]; ok && info.state == delayingMember { info.delayedReportJob.Cancel() info.lastToSendReport = false info.state = idleMember - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } // initializeNewMemberLocked initializes a new group membership. // -// Precondition: g.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state != nonMember { panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state)) @@ -465,7 +476,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres // transitionToNonMemberLocked transitions the given multicast group the the // non-member/listener state. // -// Precondition: e.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state == nonMember { return @@ -479,7 +490,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress // setDelayTimerForAddressRLocked sets timer to send a delay report. // -// Precondition: g.mu MUST be read locked. +// Precondition: g.protocolMU MUST be read locked. func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { if info.state == nonMember { return diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index 670be30d4..6fd0eb9f7 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -20,6 +20,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/network/ip" @@ -37,41 +38,86 @@ const ( var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) type mockMulticastGroupProtocol struct { + t *testing.T + + mu sync.RWMutex + + // Must only be accessed with mu held. sendReportGroupAddrCount map[tcpip.Address]int - sendLeaveGroupAddrCount map[tcpip.Address]int + + // Must only be accessed with mu held. + sendLeaveGroupAddrCount map[tcpip.Address]int } func (m *mockMulticastGroupProtocol) init() { + m.mu.Lock() + defer m.mu.Unlock() + m.initLocked() +} + +func (m *mockMulticastGroupProtocol) initLocked() { m.sendReportGroupAddrCount = make(map[tcpip.Address]int) m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) } func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + m.sendReportGroupAddrCount[groupAddress]++ return nil } func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + m.sendLeaveGroupAddrCount[groupAddress]++ return nil } -func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { - sendReportGroupAddressesMap := make(map[tcpip.Address]int) +func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { + m.mu.Lock() + defer m.mu.Unlock() + + sendReportGroupAddrCount := make(map[tcpip.Address]int) for _, a := range sendReportGroupAddresses { - sendReportGroupAddressesMap[a] = 1 + sendReportGroupAddrCount[a] = 1 } - sendLeaveGroupAddressesMap := make(map[tcpip.Address]int) + sendLeaveGroupAddrCount := make(map[tcpip.Address]int) for _, a := range sendLeaveGroupAddresses { - sendLeaveGroupAddressesMap[a] = 1 + sendLeaveGroupAddrCount[a] = 1 } - diff := cmp.Diff(mockMulticastGroupProtocol{ - sendReportGroupAddrCount: sendReportGroupAddressesMap, - sendLeaveGroupAddrCount: sendLeaveGroupAddressesMap, - }, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{})) - mgp.init() + diff := cmp.Diff( + &mockMulticastGroupProtocol{ + sendReportGroupAddrCount: sendReportGroupAddrCount, + sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, + }, + m, + cmp.AllowUnexported(mockMulticastGroupProtocol{}), + // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t + cmp.FilterPath( + func(p cmp.Path) bool { + return p.Last().String() == ".mu" || p.Last().String() == ".t" + }, + cmp.Ignore(), + ), + ) + m.initLocked() return diff } @@ -96,10 +142,11 @@ func TestJoinGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(0)), Clock: clock, @@ -110,21 +157,24 @@ func TestJoinGroup(t *testing.T) { // Joining a group should send a report immediately and another after // a random interval between 0 and the maximum unsolicited report delay. - g.JoinGroup(test.addr, false /* dontInitialize */) + mgp.mu.Lock() + g.JoinGroupLocked(test.addr, false /* dontInitialize */) + mgp.mu.Unlock() if test.shouldSendReports { - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -152,10 +202,11 @@ func TestLeaveGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(1)), Clock: clock, @@ -164,27 +215,36 @@ func TestLeaveGroup(t *testing.T) { AllNodesAddress: addr2, }) - g.JoinGroup(test.addr, false /* dontInitialize */) + mgp.mu.Lock() + g.JoinGroupLocked(test.addr, false /* dontInitialize */) + mgp.mu.Unlock() if test.shouldSendMessages { - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Leaving a group should send a leave report immediately and cancel any // delayed reports. - if !g.LeaveGroup(test.addr) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", test.addr) + { + mgp.mu.Lock() + res := g.LeaveGroupLocked(test.addr) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.LeaveGroupLocked(%s) = false, want = true", test.addr) + } } if test.shouldSendMessages { - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -227,10 +287,11 @@ func TestHandleReport(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(2)), Clock: clock, @@ -239,32 +300,41 @@ func TestHandleReport(t *testing.T) { AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr3, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receiving a report for a group we have a timer scheduled for should // cancel our delayed report timer for the group. - g.HandleReport(test.reportAddr) + mgp.mu.Lock() + g.HandleReportLocked(test.reportAddr) + mgp.mu.Unlock() if len(test.expectReportsFor) != 0 { + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -313,10 +383,11 @@ func TestHandleQuery(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(3)), Clock: clock, @@ -325,36 +396,45 @@ func TestHandleQuery(t *testing.T) { AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr3, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receiving a query should make us schedule a new delayed report if it // is a query directed at us or a general query. - g.HandleQuery(test.queryAddr, test.maxDelay) + mgp.mu.Lock() + g.HandleQueryLocked(test.queryAddr, test.maxDelay) + mgp.mu.Unlock() if len(test.expectReportsFor) != 0 { clock.Advance(test.maxDelay) - if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -363,10 +443,11 @@ func TestHandleQuery(t *testing.T) { func TestJoinCount(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(4)), Clock: clock, @@ -375,70 +456,110 @@ func TestJoinCount(t *testing.T) { }) // Set the join count to 2 for a group. - g.JoinGroup(addr1, false /* dontInitialize */) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } } // Only the first join should trigger a report to be sent. - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr1, false /* dontInitialize */) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if t.Failed() { + t.FailNow() } // Group should still be considered joined after leaving once. - if !g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) - } - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1) + } + if !isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } } // A leave report should only be sent once the join count reaches 0. - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() } // Leaving once more should actually remove us from the group. - if !g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1) + } + if isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1) + } } - if g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if t.Failed() { + t.FailNow() } // Group should no longer be joined so we should not have anything to // leave. - if g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = true, want = false", addr1) - } - if g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = true, want = false", addr1) + } + if isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1) + } } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } func TestMakeAllNonMemberAndInitialize(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(3)), Clock: clock, @@ -447,48 +568,62 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) { AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr3, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should send the leave reports for each but still consider them locally // joined. - g.MakeAllNonMember() - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.MakeAllNonMemberLocked() + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } for _, group := range []tcpip.Address{addr1, addr2, addr3} { - if !g.IsLocallyJoined(group) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", group) + mgp.mu.RLock() + res := g.IsLocallyJoinedRLocked(group) + mgp.mu.RUnlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", group) } } // Should send the initial set of unsolcited reports. - g.InitializeGroups() - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.InitializeGroupsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } @@ -521,10 +656,11 @@ func TestGroupStateNonMember(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: test.enabled, Rand: rand.New(rand.NewSource(3)), Clock: clock, @@ -532,43 +668,65 @@ func TestGroupStateNonMember(t *testing.T) { MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, }) - g.JoinGroup(addr1, test.dontInitialize) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + // Joining groups should not send any reports. + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, test.dontInitialize) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - - g.JoinGroup(addr2, test.dontInitialize) - if !g.IsLocallyJoined(addr2) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr2) + { + mgp.mu.Lock() + g.JoinGroupLocked(addr2, test.dontInitialize) + res := g.IsLocallyJoinedRLocked(addr2) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2) + } } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.HandleQuery(addr1, time.Nanosecond) + // Receiving a query should not send any reports. + mgp.mu.Lock() + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Nanosecond) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - if !g.LeaveGroup(addr2) { - t.Errorf("got g.LeaveGroup(%s) = false, want = true", addr2) - } - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) - } - if g.IsLocallyJoined(addr2) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr2) + // Leaving groups should not send any leave messages. + { + mgp.mu.Lock() + addr2LeaveRes := g.LeaveGroupLocked(addr2) + addr1IsJoined := g.IsLocallyJoinedRLocked(addr1) + addr2IsJoined := g.IsLocallyJoinedRLocked(addr2) + mgp.mu.Unlock() + if !addr2LeaveRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2) + } + if !addr1IsJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + if addr2IsJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2) + } } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) |