diff options
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol_test.go | 409 |
1 files changed, 169 insertions, 240 deletions
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index f56f7aa90..85593f211 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -37,46 +37,105 @@ const ( var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) -type mockMulticastGroupProtocol struct { - t *testing.T - - mu sync.RWMutex +type mockMulticastGroupProtocolProtectedFields struct { + sync.RWMutex - // Must only be accessed with mu held. + genericMulticastGroup ip.GenericMulticastProtocolState sendReportGroupAddrCount map[tcpip.Address]int + sendLeaveGroupAddrCount map[tcpip.Address]int + makeQueuePackets bool + disabled bool +} - // Must only be accessed with mu held. - sendLeaveGroupAddrCount map[tcpip.Address]int - - // Must only be accessed with mu held. - makeQueuePackets bool +type mockMulticastGroupProtocol struct { + t *testing.T - // Must only be accessed with mu held. - disabled bool + mu mockMulticastGroupProtocolProtectedFields } -func (m *mockMulticastGroupProtocol) init() { +func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) { m.mu.Lock() defer m.mu.Unlock() m.initLocked() + opts.Protocol = m + m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) } func (m *mockMulticastGroupProtocol) initLocked() { - m.sendReportGroupAddrCount = make(map[tcpip.Address]int) - m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) + m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int) + m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) } func (m *mockMulticastGroupProtocol) setEnabled(v bool) { m.mu.Lock() defer m.mu.Unlock() - m.disabled = !v + m.mu.disabled = !v +} + +func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.makeQueuePackets = v +} + +func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.JoinGroupLocked(addr) +} + +func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.mu.genericMulticastGroup.LeaveGroupLocked(addr) +} + +func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.HandleReportLocked(addr) +} + +func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime) +} + +func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr) +} + +func (m *mockMulticastGroupProtocol) makeAllNonMember() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.MakeAllNonMemberLocked() +} + +func (m *mockMulticastGroupProtocol) initializeGroups() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.InitializeGroupsLocked() +} + +func (m *mockMulticastGroupProtocol) sendQueuedReports() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.SendQueuedReportsLocked() } // Enabled implements ip.MulticastGroupProtocol. // // Precondition: m.mu must be read locked. func (m *mockMulticastGroupProtocol) Enabled() bool { - return !m.disabled + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") + } + + return !m.mu.disabled } // SendReport implements ip.MulticastGroupProtocol. @@ -92,8 +151,8 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) } - m.sendReportGroupAddrCount[groupAddress]++ - return !m.makeQueuePackets, nil + m.mu.sendReportGroupAddrCount[groupAddress]++ + return !m.mu.makeQueuePackets, nil } // SendLeave implements ip.MulticastGroupProtocol. @@ -109,7 +168,7 @@ func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpi m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) } - m.sendLeaveGroupAddrCount[groupAddress]++ + m.mu.sendLeaveGroupAddrCount[groupAddress]++ return nil } @@ -129,16 +188,19 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr diff := cmp.Diff( &mockMulticastGroupProtocol{ - sendReportGroupAddrCount: sendReportGroupAddrCount, - sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, + mu: mockMulticastGroupProtocolProtectedFields{ + sendReportGroupAddrCount: sendReportGroupAddrCount, + sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, + }, }, m, cmp.AllowUnexported(mockMulticastGroupProtocol{}), + cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t cmp.FilterPath( func(p cmp.Path) bool { switch p.Last().String() { - case ".mu", ".t", ".makeQueuePackets", ".disabled": + case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": return true } return false @@ -170,24 +232,19 @@ func TestJoinGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - mgp.init() - g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(0)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr2, }) // Joining a group should send a report immediately and another after // a random interval between 0 and the maximum unsolicited report delay. - mgp.mu.Lock() - g.JoinGroupLocked(test.addr) - mgp.mu.Unlock() + mgp.joinGroup(test.addr) if test.shouldSendReports { if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -229,22 +286,17 @@ func TestLeaveGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - mgp.init() - g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(1)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr2, }) - mgp.mu.Lock() - g.JoinGroupLocked(test.addr) - mgp.mu.Unlock() + mgp.joinGroup(test.addr) if test.shouldSendMessages { if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -254,11 +306,9 @@ func TestLeaveGroup(t *testing.T) { // Leaving a group should send a leave report immediately and cancel any // delayed reports. { - mgp.mu.Lock() - res := g.LeaveGroupLocked(test.addr) - mgp.mu.Unlock() - if !res { - t.Fatalf("got g.LeaveGroupLocked(%s) = false, want = true", test.addr) + + if !mgp.leaveGroup(test.addr) { + t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) } } if test.shouldSendMessages { @@ -313,43 +363,32 @@ func TestHandleReport(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - mgp.init() - g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(2)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr3, }) - mgp.mu.Lock() - g.JoinGroupLocked(addr1) - mgp.mu.Unlock() + mgp.joinGroup(addr1) if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - mgp.mu.Lock() - g.JoinGroupLocked(addr2) - mgp.mu.Unlock() + mgp.joinGroup(addr2) if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - mgp.mu.Lock() - g.JoinGroupLocked(addr3) - mgp.mu.Unlock() + mgp.joinGroup(addr3) if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receiving a report for a group we have a timer scheduled for should // cancel our delayed report timer for the group. - mgp.mu.Lock() - g.HandleReportLocked(test.reportAddr) - mgp.mu.Unlock() + mgp.handleReport(test.reportAddr) if len(test.expectReportsFor) != 0 { // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) @@ -408,34 +447,25 @@ func TestHandleQuery(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - mgp.init() - g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr3, }) - mgp.mu.Lock() - g.JoinGroupLocked(addr1) - mgp.mu.Unlock() + mgp.joinGroup(addr1) if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - mgp.mu.Lock() - g.JoinGroupLocked(addr2) - mgp.mu.Unlock() + mgp.joinGroup(addr2) if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - mgp.mu.Lock() - g.JoinGroupLocked(addr3) - mgp.mu.Unlock() + mgp.joinGroup(addr3) if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } @@ -447,9 +477,7 @@ func TestHandleQuery(t *testing.T) { // Receiving a query should make us schedule a new delayed report if it // is a query directed at us or a general query. - mgp.mu.Lock() - g.HandleQueryLocked(test.queryAddr, test.maxDelay) - mgp.mu.Unlock() + mgp.handleQuery(test.queryAddr, test.maxDelay) if len(test.expectReportsFor) != 0 { clock.Advance(test.maxDelay) if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { @@ -467,40 +495,27 @@ func TestHandleQuery(t *testing.T) { } func TestJoinCount(t *testing.T) { - var g ip.GenericMulticastProtocolState mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - mgp.init() - g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(4)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: time.Second, }) // Set the join count to 2 for a group. - { - mgp.mu.Lock() - g.JoinGroupLocked(addr1) - res := g.IsLocallyJoinedRLocked(addr1) - mgp.mu.Unlock() - if !res { - t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) - } + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) } // Only the first join should trigger a report to be sent. if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - { - mgp.mu.Lock() - g.JoinGroupLocked(addr1) - res := g.IsLocallyJoinedRLocked(addr1) - mgp.mu.Unlock() - if !res { - t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) - } + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) } if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -510,17 +525,11 @@ func TestJoinCount(t *testing.T) { } // Group should still be considered joined after leaving once. - { - mgp.mu.Lock() - leaveGroupRes := g.LeaveGroupLocked(addr1) - isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) - mgp.mu.Unlock() - if !leaveGroupRes { - t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1) - } - if !isLocallyJoined { - t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) - } + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) + } + if !mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) } // A leave report should only be sent once the join count reaches 0. if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { @@ -531,17 +540,11 @@ func TestJoinCount(t *testing.T) { } // Leaving once more should actually remove us from the group. - { - mgp.mu.Lock() - leaveGroupRes := g.LeaveGroupLocked(addr1) - isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) - mgp.mu.Unlock() - if !leaveGroupRes { - t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1) - } - if isLocallyJoined { - t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1) - } + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) } if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -552,17 +555,11 @@ func TestJoinCount(t *testing.T) { // Group should no longer be joined so we should not have anything to // leave. - { - mgp.mu.Lock() - leaveGroupRes := g.LeaveGroupLocked(addr1) - isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) - mgp.mu.Unlock() - if leaveGroupRes { - t.Errorf("got g.LeaveGroupLocked(%s) = true, want = false", addr1) - } - if isLocallyJoined { - t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1) - } + if mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) } if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -578,43 +575,32 @@ func TestJoinCount(t *testing.T) { } func TestMakeAllNonMemberAndInitialize(t *testing.T) { - var g ip.GenericMulticastProtocolState mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - mgp.init() - g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr3, }) - mgp.mu.Lock() - g.JoinGroupLocked(addr1) - mgp.mu.Unlock() + mgp.joinGroup(addr1) if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - mgp.mu.Lock() - g.JoinGroupLocked(addr2) - mgp.mu.Unlock() + mgp.joinGroup(addr2) if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - mgp.mu.Lock() - g.JoinGroupLocked(addr3) - mgp.mu.Unlock() + mgp.joinGroup(addr3) if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should send the leave reports for each but still consider them locally // joined. - mgp.mu.Lock() - g.MakeAllNonMemberLocked() - mgp.mu.Unlock() + mgp.makeAllNonMember() if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } @@ -624,18 +610,13 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } for _, group := range []tcpip.Address{addr1, addr2, addr3} { - mgp.mu.RLock() - res := g.IsLocallyJoinedRLocked(group) - mgp.mu.RUnlock() - if !res { - t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", group) + if !mgp.isLocallyJoined(group) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) } } // Should send the initial set of unsolcited reports. - mgp.mu.Lock() - g.InitializeGroupsLocked() - mgp.mu.Unlock() + mgp.initializeGroups() if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } @@ -654,49 +635,34 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) { // TestGroupStateNonMember tests that groups do not send packets when in the // non-member state, but are still considered locally joined. func TestGroupStateNonMember(t *testing.T) { - var g ip.GenericMulticastProtocolState mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - mgp.init() - mgp.setEnabled(false) - g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, }) + mgp.setEnabled(false) // Joining groups should not send any reports. - { - mgp.mu.Lock() - g.JoinGroupLocked(addr1) - res := g.IsLocallyJoinedRLocked(addr1) - mgp.mu.Unlock() - if !res { - t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) - } + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) } if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - { - mgp.mu.Lock() - g.JoinGroupLocked(addr2) - res := g.IsLocallyJoinedRLocked(addr2) - mgp.mu.Unlock() - if !res { - t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2) - } + mgp.joinGroup(addr2) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) } if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receiving a query should not send any reports. - mgp.mu.Lock() - g.HandleQueryLocked(addr1, time.Nanosecond) - mgp.mu.Unlock() + mgp.handleQuery(addr1, time.Nanosecond) // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Nanosecond) if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { @@ -704,21 +670,11 @@ func TestGroupStateNonMember(t *testing.T) { } // Leaving groups should not send any leave messages. - { - mgp.mu.Lock() - addr2LeaveRes := g.LeaveGroupLocked(addr2) - addr1IsJoined := g.IsLocallyJoinedRLocked(addr1) - addr2IsJoined := g.IsLocallyJoinedRLocked(addr2) - mgp.mu.Unlock() - if !addr2LeaveRes { - t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2) - } - if !addr1IsJoined { - t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) - } - if addr2IsJoined { - t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2) - } + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2) } if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -731,23 +687,18 @@ func TestGroupStateNonMember(t *testing.T) { } func TestQueuedPackets(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() clock := faketime.NewManualClock() - g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + mgp := mockMulticastGroupProtocol{t: t} + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(4)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, }) // Joining should trigger a SendReport, but mgp should report that we did not // send the packet. - mgp.mu.Lock() - mgp.makeQueuePackets = true - g.JoinGroupLocked(addr1) - mgp.mu.Unlock() + mgp.setQueuePackets(true) + mgp.joinGroup(addr1) if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } @@ -760,10 +711,8 @@ func TestQueuedPackets(t *testing.T) { } // Mock being able to successfully send the report. - mgp.mu.Lock() - mgp.makeQueuePackets = false - g.SendQueuedReportsLocked() - mgp.mu.Unlock() + mgp.setQueuePackets(false) + mgp.sendQueuedReports() if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } @@ -775,19 +724,15 @@ func TestQueuedPackets(t *testing.T) { } // Should not have anything else to send (we should be idle). - mgp.mu.Lock() - g.SendQueuedReportsLocked() - mgp.mu.Unlock() + mgp.sendQueuedReports() clock.Advance(time.Hour) if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receive a query but mock being unable to send reports again. - mgp.mu.Lock() - mgp.makeQueuePackets = true - g.HandleQueryLocked(addr1, time.Nanosecond) - mgp.mu.Unlock() + mgp.setQueuePackets(true) + mgp.handleQuery(addr1, time.Nanosecond) clock.Advance(time.Nanosecond) if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -795,28 +740,22 @@ func TestQueuedPackets(t *testing.T) { // Mock being able to send reports again - we should have a packet queued to // send. - mgp.mu.Lock() - mgp.makeQueuePackets = false - g.SendQueuedReportsLocked() - mgp.mu.Unlock() + mgp.setQueuePackets(false) + mgp.sendQueuedReports() if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should not have anything else to send. - mgp.mu.Lock() - g.SendQueuedReportsLocked() - mgp.mu.Unlock() + mgp.sendQueuedReports() clock.Advance(time.Hour) if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receive a query again, but mock being unable to send reports. - mgp.mu.Lock() - mgp.makeQueuePackets = true - g.HandleQueryLocked(addr1, time.Nanosecond) - mgp.mu.Unlock() + mgp.setQueuePackets(true) + mgp.handleQuery(addr1, time.Nanosecond) clock.Advance(time.Nanosecond) if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -825,10 +764,8 @@ func TestQueuedPackets(t *testing.T) { // Receiving a report should should transition us into the idle member state, // even if we had a packet queued. We should no longer have any packets to // send. - mgp.mu.Lock() - g.HandleReportLocked(addr1) - g.SendQueuedReportsLocked() - mgp.mu.Unlock() + mgp.handleReport(addr1) + mgp.sendQueuedReports() clock.Advance(time.Hour) if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -836,27 +773,21 @@ func TestQueuedPackets(t *testing.T) { // When we fail to send the initial set of reports, incoming reports should // not affect a newly joined group's reports from being sent. - mgp.mu.Lock() - mgp.makeQueuePackets = true - g.JoinGroupLocked(addr2) - mgp.mu.Unlock() + mgp.setQueuePackets(true) + mgp.joinGroup(addr2) if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - mgp.mu.Lock() - g.HandleReportLocked(addr2) + mgp.handleReport(addr2) // Attempting to send queued reports while still unable to send reports should // not change the host state. - g.SendQueuedReportsLocked() - mgp.mu.Unlock() + mgp.sendQueuedReports() if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Mock being able to successfully send the report. - mgp.mu.Lock() - mgp.makeQueuePackets = false - g.SendQueuedReportsLocked() - mgp.mu.Unlock() + mgp.setQueuePackets(false) + mgp.sendQueuedReports() if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } @@ -867,9 +798,7 @@ func TestQueuedPackets(t *testing.T) { } // Should not have anything else to send. - mgp.mu.Lock() - g.SendQueuedReportsLocked() - mgp.mu.Unlock() + mgp.sendQueuedReports() clock.Advance(time.Hour) if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) |