summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go409
1 files changed, 169 insertions, 240 deletions
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
index f56f7aa90..85593f211 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -37,46 +37,105 @@ const (
var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
-type mockMulticastGroupProtocol struct {
- t *testing.T
-
- mu sync.RWMutex
+type mockMulticastGroupProtocolProtectedFields struct {
+ sync.RWMutex
- // Must only be accessed with mu held.
+ genericMulticastGroup ip.GenericMulticastProtocolState
sendReportGroupAddrCount map[tcpip.Address]int
+ sendLeaveGroupAddrCount map[tcpip.Address]int
+ makeQueuePackets bool
+ disabled bool
+}
- // Must only be accessed with mu held.
- sendLeaveGroupAddrCount map[tcpip.Address]int
-
- // Must only be accessed with mu held.
- makeQueuePackets bool
+type mockMulticastGroupProtocol struct {
+ t *testing.T
- // Must only be accessed with mu held.
- disabled bool
+ mu mockMulticastGroupProtocolProtectedFields
}
-func (m *mockMulticastGroupProtocol) init() {
+func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) {
m.mu.Lock()
defer m.mu.Unlock()
m.initLocked()
+ opts.Protocol = m
+ m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
}
func (m *mockMulticastGroupProtocol) initLocked() {
- m.sendReportGroupAddrCount = make(map[tcpip.Address]int)
- m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
+ m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int)
+ m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
}
func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
m.mu.Lock()
defer m.mu.Unlock()
- m.disabled = !v
+ m.mu.disabled = !v
+}
+
+func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.makeQueuePackets = v
+}
+
+func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.JoinGroupLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.mu.genericMulticastGroup.LeaveGroupLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.HandleReportLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime)
+}
+
+func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) makeAllNonMember() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.MakeAllNonMemberLocked()
+}
+
+func (m *mockMulticastGroupProtocol) initializeGroups() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.InitializeGroupsLocked()
+}
+
+func (m *mockMulticastGroupProtocol) sendQueuedReports() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.SendQueuedReportsLocked()
}
// Enabled implements ip.MulticastGroupProtocol.
//
// Precondition: m.mu must be read locked.
func (m *mockMulticastGroupProtocol) Enabled() bool {
- return !m.disabled
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
+ }
+
+ return !m.mu.disabled
}
// SendReport implements ip.MulticastGroupProtocol.
@@ -92,8 +151,8 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo
m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
}
- m.sendReportGroupAddrCount[groupAddress]++
- return !m.makeQueuePackets, nil
+ m.mu.sendReportGroupAddrCount[groupAddress]++
+ return !m.mu.makeQueuePackets, nil
}
// SendLeave implements ip.MulticastGroupProtocol.
@@ -109,7 +168,7 @@ func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpi
m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
}
- m.sendLeaveGroupAddrCount[groupAddress]++
+ m.mu.sendLeaveGroupAddrCount[groupAddress]++
return nil
}
@@ -129,16 +188,19 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr
diff := cmp.Diff(
&mockMulticastGroupProtocol{
- sendReportGroupAddrCount: sendReportGroupAddrCount,
- sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
+ mu: mockMulticastGroupProtocolProtectedFields{
+ sendReportGroupAddrCount: sendReportGroupAddrCount,
+ sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
+ },
},
m,
cmp.AllowUnexported(mockMulticastGroupProtocol{}),
+ cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}),
// ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
cmp.FilterPath(
func(p cmp.Path) bool {
switch p.Last().String() {
- case ".mu", ".t", ".makeQueuePackets", ".disabled":
+ case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup":
return true
}
return false
@@ -170,24 +232,19 @@ func TestJoinGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- mgp.init()
- g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(0)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
AllNodesAddress: addr2,
})
// Joining a group should send a report immediately and another after
// a random interval between 0 and the maximum unsolicited report delay.
- mgp.mu.Lock()
- g.JoinGroupLocked(test.addr)
- mgp.mu.Unlock()
+ mgp.joinGroup(test.addr)
if test.shouldSendReports {
if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -229,22 +286,17 @@ func TestLeaveGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- mgp.init()
- g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(1)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
AllNodesAddress: addr2,
})
- mgp.mu.Lock()
- g.JoinGroupLocked(test.addr)
- mgp.mu.Unlock()
+ mgp.joinGroup(test.addr)
if test.shouldSendMessages {
if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -254,11 +306,9 @@ func TestLeaveGroup(t *testing.T) {
// Leaving a group should send a leave report immediately and cancel any
// delayed reports.
{
- mgp.mu.Lock()
- res := g.LeaveGroupLocked(test.addr)
- mgp.mu.Unlock()
- if !res {
- t.Fatalf("got g.LeaveGroupLocked(%s) = false, want = true", test.addr)
+
+ if !mgp.leaveGroup(test.addr) {
+ t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr)
}
}
if test.shouldSendMessages {
@@ -313,43 +363,32 @@ func TestHandleReport(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- mgp.init()
- g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(2)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
AllNodesAddress: addr3,
})
- mgp.mu.Lock()
- g.JoinGroupLocked(addr1)
- mgp.mu.Unlock()
+ mgp.joinGroup(addr1)
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- mgp.mu.Lock()
- g.JoinGroupLocked(addr2)
- mgp.mu.Unlock()
+ mgp.joinGroup(addr2)
if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- mgp.mu.Lock()
- g.JoinGroupLocked(addr3)
- mgp.mu.Unlock()
+ mgp.joinGroup(addr3)
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Receiving a report for a group we have a timer scheduled for should
// cancel our delayed report timer for the group.
- mgp.mu.Lock()
- g.HandleReportLocked(test.reportAddr)
- mgp.mu.Unlock()
+ mgp.handleReport(test.reportAddr)
if len(test.expectReportsFor) != 0 {
// Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(maxUnsolicitedReportDelay)
@@ -408,34 +447,25 @@ func TestHandleQuery(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- mgp.init()
- g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
AllNodesAddress: addr3,
})
- mgp.mu.Lock()
- g.JoinGroupLocked(addr1)
- mgp.mu.Unlock()
+ mgp.joinGroup(addr1)
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- mgp.mu.Lock()
- g.JoinGroupLocked(addr2)
- mgp.mu.Unlock()
+ mgp.joinGroup(addr2)
if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- mgp.mu.Lock()
- g.JoinGroupLocked(addr3)
- mgp.mu.Unlock()
+ mgp.joinGroup(addr3)
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
@@ -447,9 +477,7 @@ func TestHandleQuery(t *testing.T) {
// Receiving a query should make us schedule a new delayed report if it
// is a query directed at us or a general query.
- mgp.mu.Lock()
- g.HandleQueryLocked(test.queryAddr, test.maxDelay)
- mgp.mu.Unlock()
+ mgp.handleQuery(test.queryAddr, test.maxDelay)
if len(test.expectReportsFor) != 0 {
clock.Advance(test.maxDelay)
if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
@@ -467,40 +495,27 @@ func TestHandleQuery(t *testing.T) {
}
func TestJoinCount(t *testing.T) {
- var g ip.GenericMulticastProtocolState
mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- mgp.init()
- g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(4)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: time.Second,
})
// Set the join count to 2 for a group.
- {
- mgp.mu.Lock()
- g.JoinGroupLocked(addr1)
- res := g.IsLocallyJoinedRLocked(addr1)
- mgp.mu.Unlock()
- if !res {
- t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
- }
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
}
// Only the first join should trigger a report to be sent.
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- {
- mgp.mu.Lock()
- g.JoinGroupLocked(addr1)
- res := g.IsLocallyJoinedRLocked(addr1)
- mgp.mu.Unlock()
- if !res {
- t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
- }
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
}
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -510,17 +525,11 @@ func TestJoinCount(t *testing.T) {
}
// Group should still be considered joined after leaving once.
- {
- mgp.mu.Lock()
- leaveGroupRes := g.LeaveGroupLocked(addr1)
- isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
- mgp.mu.Unlock()
- if !leaveGroupRes {
- t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1)
- }
- if !isLocallyJoined {
- t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
- }
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
+ }
+ if !mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
}
// A leave report should only be sent once the join count reaches 0.
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
@@ -531,17 +540,11 @@ func TestJoinCount(t *testing.T) {
}
// Leaving once more should actually remove us from the group.
- {
- mgp.mu.Lock()
- leaveGroupRes := g.LeaveGroupLocked(addr1)
- isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
- mgp.mu.Unlock()
- if !leaveGroupRes {
- t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1)
- }
- if isLocallyJoined {
- t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1)
- }
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
}
if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -552,17 +555,11 @@ func TestJoinCount(t *testing.T) {
// Group should no longer be joined so we should not have anything to
// leave.
- {
- mgp.mu.Lock()
- leaveGroupRes := g.LeaveGroupLocked(addr1)
- isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
- mgp.mu.Unlock()
- if leaveGroupRes {
- t.Errorf("got g.LeaveGroupLocked(%s) = true, want = false", addr1)
- }
- if isLocallyJoined {
- t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1)
- }
+ if mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
}
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -578,43 +575,32 @@ func TestJoinCount(t *testing.T) {
}
func TestMakeAllNonMemberAndInitialize(t *testing.T) {
- var g ip.GenericMulticastProtocolState
mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- mgp.init()
- g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
AllNodesAddress: addr3,
})
- mgp.mu.Lock()
- g.JoinGroupLocked(addr1)
- mgp.mu.Unlock()
+ mgp.joinGroup(addr1)
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- mgp.mu.Lock()
- g.JoinGroupLocked(addr2)
- mgp.mu.Unlock()
+ mgp.joinGroup(addr2)
if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- mgp.mu.Lock()
- g.JoinGroupLocked(addr3)
- mgp.mu.Unlock()
+ mgp.joinGroup(addr3)
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Should send the leave reports for each but still consider them locally
// joined.
- mgp.mu.Lock()
- g.MakeAllNonMemberLocked()
- mgp.mu.Unlock()
+ mgp.makeAllNonMember()
if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
@@ -624,18 +610,13 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
for _, group := range []tcpip.Address{addr1, addr2, addr3} {
- mgp.mu.RLock()
- res := g.IsLocallyJoinedRLocked(group)
- mgp.mu.RUnlock()
- if !res {
- t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", group)
+ if !mgp.isLocallyJoined(group) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group)
}
}
// Should send the initial set of unsolcited reports.
- mgp.mu.Lock()
- g.InitializeGroupsLocked()
- mgp.mu.Unlock()
+ mgp.initializeGroups()
if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
@@ -654,49 +635,34 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) {
// TestGroupStateNonMember tests that groups do not send packets when in the
// non-member state, but are still considered locally joined.
func TestGroupStateNonMember(t *testing.T) {
- var g ip.GenericMulticastProtocolState
mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- mgp.init()
- mgp.setEnabled(false)
- g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
})
+ mgp.setEnabled(false)
// Joining groups should not send any reports.
- {
- mgp.mu.Lock()
- g.JoinGroupLocked(addr1)
- res := g.IsLocallyJoinedRLocked(addr1)
- mgp.mu.Unlock()
- if !res {
- t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
- }
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
}
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- {
- mgp.mu.Lock()
- g.JoinGroupLocked(addr2)
- res := g.IsLocallyJoinedRLocked(addr2)
- mgp.mu.Unlock()
- if !res {
- t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2)
- }
+ mgp.joinGroup(addr2)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2)
}
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Receiving a query should not send any reports.
- mgp.mu.Lock()
- g.HandleQueryLocked(addr1, time.Nanosecond)
- mgp.mu.Unlock()
+ mgp.handleQuery(addr1, time.Nanosecond)
// Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(time.Nanosecond)
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
@@ -704,21 +670,11 @@ func TestGroupStateNonMember(t *testing.T) {
}
// Leaving groups should not send any leave messages.
- {
- mgp.mu.Lock()
- addr2LeaveRes := g.LeaveGroupLocked(addr2)
- addr1IsJoined := g.IsLocallyJoinedRLocked(addr1)
- addr2IsJoined := g.IsLocallyJoinedRLocked(addr2)
- mgp.mu.Unlock()
- if !addr2LeaveRes {
- t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2)
- }
- if !addr1IsJoined {
- t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
- }
- if addr2IsJoined {
- t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2)
- }
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2)
}
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -731,23 +687,18 @@ func TestGroupStateNonMember(t *testing.T) {
}
func TestQueuedPackets(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
clock := faketime.NewManualClock()
- g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ mgp := mockMulticastGroupProtocol{t: t}
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(4)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
})
// Joining should trigger a SendReport, but mgp should report that we did not
// send the packet.
- mgp.mu.Lock()
- mgp.makeQueuePackets = true
- g.JoinGroupLocked(addr1)
- mgp.mu.Unlock()
+ mgp.setQueuePackets(true)
+ mgp.joinGroup(addr1)
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
@@ -760,10 +711,8 @@ func TestQueuedPackets(t *testing.T) {
}
// Mock being able to successfully send the report.
- mgp.mu.Lock()
- mgp.makeQueuePackets = false
- g.SendQueuedReportsLocked()
- mgp.mu.Unlock()
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
@@ -775,19 +724,15 @@ func TestQueuedPackets(t *testing.T) {
}
// Should not have anything else to send (we should be idle).
- mgp.mu.Lock()
- g.SendQueuedReportsLocked()
- mgp.mu.Unlock()
+ mgp.sendQueuedReports()
clock.Advance(time.Hour)
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Receive a query but mock being unable to send reports again.
- mgp.mu.Lock()
- mgp.makeQueuePackets = true
- g.HandleQueryLocked(addr1, time.Nanosecond)
- mgp.mu.Unlock()
+ mgp.setQueuePackets(true)
+ mgp.handleQuery(addr1, time.Nanosecond)
clock.Advance(time.Nanosecond)
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -795,28 +740,22 @@ func TestQueuedPackets(t *testing.T) {
// Mock being able to send reports again - we should have a packet queued to
// send.
- mgp.mu.Lock()
- mgp.makeQueuePackets = false
- g.SendQueuedReportsLocked()
- mgp.mu.Unlock()
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Should not have anything else to send.
- mgp.mu.Lock()
- g.SendQueuedReportsLocked()
- mgp.mu.Unlock()
+ mgp.sendQueuedReports()
clock.Advance(time.Hour)
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Receive a query again, but mock being unable to send reports.
- mgp.mu.Lock()
- mgp.makeQueuePackets = true
- g.HandleQueryLocked(addr1, time.Nanosecond)
- mgp.mu.Unlock()
+ mgp.setQueuePackets(true)
+ mgp.handleQuery(addr1, time.Nanosecond)
clock.Advance(time.Nanosecond)
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -825,10 +764,8 @@ func TestQueuedPackets(t *testing.T) {
// Receiving a report should should transition us into the idle member state,
// even if we had a packet queued. We should no longer have any packets to
// send.
- mgp.mu.Lock()
- g.HandleReportLocked(addr1)
- g.SendQueuedReportsLocked()
- mgp.mu.Unlock()
+ mgp.handleReport(addr1)
+ mgp.sendQueuedReports()
clock.Advance(time.Hour)
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -836,27 +773,21 @@ func TestQueuedPackets(t *testing.T) {
// When we fail to send the initial set of reports, incoming reports should
// not affect a newly joined group's reports from being sent.
- mgp.mu.Lock()
- mgp.makeQueuePackets = true
- g.JoinGroupLocked(addr2)
- mgp.mu.Unlock()
+ mgp.setQueuePackets(true)
+ mgp.joinGroup(addr2)
if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- mgp.mu.Lock()
- g.HandleReportLocked(addr2)
+ mgp.handleReport(addr2)
// Attempting to send queued reports while still unable to send reports should
// not change the host state.
- g.SendQueuedReportsLocked()
- mgp.mu.Unlock()
+ mgp.sendQueuedReports()
if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Mock being able to successfully send the report.
- mgp.mu.Lock()
- mgp.makeQueuePackets = false
- g.SendQueuedReportsLocked()
- mgp.mu.Unlock()
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
@@ -867,9 +798,7 @@ func TestQueuedPackets(t *testing.T) {
}
// Should not have anything else to send.
- mgp.mu.Lock()
- g.SendQueuedReportsLocked()
- mgp.mu.Unlock()
+ mgp.sendQueuedReports()
clock.Advance(time.Hour)
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)