summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go49
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go227
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go17
-rw-r--r--pkg/tcpip/network/ipv6/mld.go12
4 files changed, 154 insertions, 151 deletions
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
index f85c5ff9d..f2f0e069c 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -131,17 +131,6 @@ type multicastGroupState struct {
// GenericMulticastProtocolOptions holds options for the generic multicast
// protocol.
type GenericMulticastProtocolOptions struct {
- // Enabled indicates whether the generic multicast protocol will be
- // performed.
- //
- // When enabled, the protocol may transmit report and leave messages when
- // joining and leaving multicast groups respectively, and handle incoming
- // packets.
- //
- // When disabled, the protocol will still keep track of locally joined groups,
- // it just won't transmit and handle packets, or update groups' state.
- Enabled bool
-
// Rand is the source of random numbers.
Rand *rand.Rand
@@ -170,6 +159,17 @@ type GenericMulticastProtocolOptions struct {
// MulticastGroupProtocol is a multicast group protocol whose core state machine
// can be represented by GenericMulticastProtocolState.
type MulticastGroupProtocol interface {
+ // Enabled indicates whether the generic multicast protocol will be
+ // performed.
+ //
+ // When enabled, the protocol may transmit report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // packets.
+ //
+ // When disabled, the protocol will still keep track of locally joined groups,
+ // it just won't transmit and handle packets, or update groups' state.
+ Enabled() bool
+
// SendReport sends a multicast report for the specified group address.
//
// Returns false if the caller should queue the report to be sent later. Note,
@@ -196,6 +196,9 @@ type MulticastGroupProtocol interface {
//
// GenericMulticastProtocolState.Init MUST be called before calling any of
// the methods on GenericMulticastProtocolState.
+//
+// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the
+// multicast group protocol is disabled so that leave messages may be sent.
type GenericMulticastProtocolState struct {
// Do not allow overwriting this state.
_ sync.NoCopy
@@ -235,9 +238,11 @@ func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts Gene
//
// The groups will still be considered joined locally.
//
+// MUST be called when the multicast group protocol is disabled.
+//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
- if !g.opts.Enabled {
+ if !g.opts.Protocol.Enabled() {
return
}
@@ -255,7 +260,7 @@ func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
- if !g.opts.Enabled {
+ if !g.opts.Protocol.Enabled() {
return
}
@@ -290,12 +295,8 @@ func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
// JoinGroupLocked handles joining a new group.
//
-// If dontInitialize is true, the group will be not be initialized and will be
-// left in the non-member state - no packets will be sent for it until it is
-// initialized via InitializeGroups.
-//
// Precondition: g.protocolMU must be locked.
-func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address, dontInitialize bool) {
+func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) {
if info, ok := g.memberships[groupAddress]; ok {
// The group has already been joined.
info.joins++
@@ -310,6 +311,10 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre
state: nonMember,
lastToSendReport: false,
delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() {
+ if !g.opts.Protocol.Enabled() {
+ panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress))
+ }
+
info, ok := g.memberships[groupAddress]
if !ok {
panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
@@ -320,7 +325,7 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre
}),
}
- if !dontInitialize && g.opts.Enabled {
+ if g.opts.Protocol.Enabled() {
g.initializeNewMemberLocked(groupAddress, &info)
}
@@ -372,7 +377,7 @@ func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Addr
//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) {
- if !g.opts.Enabled {
+ if !g.opts.Protocol.Enabled() {
return
}
@@ -406,7 +411,7 @@ func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Add
//
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) {
- if !g.opts.Enabled {
+ if !g.opts.Protocol.Enabled() {
return
}
@@ -518,7 +523,7 @@ func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddres
// maybeSendLeave attempts to send a leave message.
func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
- if !g.opts.Enabled || !lastToSendReport {
+ if !g.opts.Protocol.Enabled() || !lastToSendReport {
return
}
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
index 95040515c..f56f7aa90 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -50,6 +50,9 @@ type mockMulticastGroupProtocol struct {
// Must only be accessed with mu held.
makeQueuePackets bool
+
+ // Must only be accessed with mu held.
+ disabled bool
}
func (m *mockMulticastGroupProtocol) init() {
@@ -63,6 +66,22 @@ func (m *mockMulticastGroupProtocol) initLocked() {
m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
}
+func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.disabled = !v
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be read locked.
+func (m *mockMulticastGroupProtocol) Enabled() bool {
+ return !m.disabled
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
if m.mu.TryLock() {
m.mu.Unlock()
@@ -77,6 +96,9 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo
return !m.makeQueuePackets, nil
}
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
if m.mu.TryLock() {
m.mu.Unlock()
@@ -115,7 +137,11 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr
// ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
cmp.FilterPath(
func(p cmp.Path) bool {
- return p.Last().String() == ".mu" || p.Last().String() == ".t" || p.Last().String() == ".makeQueuePackets"
+ switch p.Last().String() {
+ case ".mu", ".t", ".makeQueuePackets", ".disabled":
+ return true
+ }
+ return false
},
cmp.Ignore(),
),
@@ -150,7 +176,6 @@ func TestJoinGroup(t *testing.T) {
mgp.init()
g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
- Enabled: true,
Rand: rand.New(rand.NewSource(0)),
Clock: clock,
Protocol: &mgp,
@@ -161,7 +186,7 @@ func TestJoinGroup(t *testing.T) {
// Joining a group should send a report immediately and another after
// a random interval between 0 and the maximum unsolicited report delay.
mgp.mu.Lock()
- g.JoinGroupLocked(test.addr, false /* dontInitialize */)
+ g.JoinGroupLocked(test.addr)
mgp.mu.Unlock()
if test.shouldSendReports {
if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
@@ -210,7 +235,6 @@ func TestLeaveGroup(t *testing.T) {
mgp.init()
g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
- Enabled: true,
Rand: rand.New(rand.NewSource(1)),
Clock: clock,
Protocol: &mgp,
@@ -219,7 +243,7 @@ func TestLeaveGroup(t *testing.T) {
})
mgp.mu.Lock()
- g.JoinGroupLocked(test.addr, false /* dontInitialize */)
+ g.JoinGroupLocked(test.addr)
mgp.mu.Unlock()
if test.shouldSendMessages {
if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
@@ -295,7 +319,6 @@ func TestHandleReport(t *testing.T) {
mgp.init()
g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
- Enabled: true,
Rand: rand.New(rand.NewSource(2)),
Clock: clock,
Protocol: &mgp,
@@ -304,19 +327,19 @@ func TestHandleReport(t *testing.T) {
})
mgp.mu.Lock()
- g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ g.JoinGroupLocked(addr1)
mgp.mu.Unlock()
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
mgp.mu.Lock()
- g.JoinGroupLocked(addr2, false /* dontInitialize */)
+ g.JoinGroupLocked(addr2)
mgp.mu.Unlock()
if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
mgp.mu.Lock()
- g.JoinGroupLocked(addr3, false /* dontInitialize */)
+ g.JoinGroupLocked(addr3)
mgp.mu.Unlock()
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -391,7 +414,6 @@ func TestHandleQuery(t *testing.T) {
mgp.init()
g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
- Enabled: true,
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
Protocol: &mgp,
@@ -400,19 +422,19 @@ func TestHandleQuery(t *testing.T) {
})
mgp.mu.Lock()
- g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ g.JoinGroupLocked(addr1)
mgp.mu.Unlock()
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
mgp.mu.Lock()
- g.JoinGroupLocked(addr2, false /* dontInitialize */)
+ g.JoinGroupLocked(addr2)
mgp.mu.Unlock()
if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
mgp.mu.Lock()
- g.JoinGroupLocked(addr3, false /* dontInitialize */)
+ g.JoinGroupLocked(addr3)
mgp.mu.Unlock()
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -451,7 +473,6 @@ func TestJoinCount(t *testing.T) {
mgp.init()
g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
- Enabled: true,
Rand: rand.New(rand.NewSource(4)),
Clock: clock,
Protocol: &mgp,
@@ -461,7 +482,7 @@ func TestJoinCount(t *testing.T) {
// Set the join count to 2 for a group.
{
mgp.mu.Lock()
- g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ g.JoinGroupLocked(addr1)
res := g.IsLocallyJoinedRLocked(addr1)
mgp.mu.Unlock()
if !res {
@@ -474,7 +495,7 @@ func TestJoinCount(t *testing.T) {
}
{
mgp.mu.Lock()
- g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ g.JoinGroupLocked(addr1)
res := g.IsLocallyJoinedRLocked(addr1)
mgp.mu.Unlock()
if !res {
@@ -563,7 +584,6 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) {
mgp.init()
g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
- Enabled: true,
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
Protocol: &mgp,
@@ -572,19 +592,19 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) {
})
mgp.mu.Lock()
- g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ g.JoinGroupLocked(addr1)
mgp.mu.Unlock()
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
mgp.mu.Lock()
- g.JoinGroupLocked(addr2, false /* dontInitialize */)
+ g.JoinGroupLocked(addr2)
mgp.mu.Unlock()
if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
mgp.mu.Lock()
- g.JoinGroupLocked(addr3, false /* dontInitialize */)
+ g.JoinGroupLocked(addr3)
mgp.mu.Unlock()
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -634,105 +654,79 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) {
// TestGroupStateNonMember tests that groups do not send packets when in the
// non-member state, but are still considered locally joined.
func TestGroupStateNonMember(t *testing.T) {
- tests := []struct {
- name string
- enabled bool
- dontInitialize bool
- }{
- {
- name: "Disabled",
- enabled: false,
- dontInitialize: false,
- },
- {
- name: "Keep non-member",
- enabled: true,
- dontInitialize: true,
- },
- {
- name: "disabled and Keep non-member",
- enabled: false,
- dontInitialize: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- mgp := mockMulticastGroupProtocol{t: t}
- clock := faketime.NewManualClock()
+ var g ip.GenericMulticastProtocolState
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
- mgp.init()
- g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
- Enabled: test.enabled,
- Rand: rand.New(rand.NewSource(3)),
- Clock: clock,
- Protocol: &mgp,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
+ mgp.init()
+ mgp.setEnabled(false)
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
- // Joining groups should not send any reports.
- {
- mgp.mu.Lock()
- g.JoinGroupLocked(addr1, test.dontInitialize)
- res := g.IsLocallyJoinedRLocked(addr1)
- mgp.mu.Unlock()
- if !res {
- t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
- }
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- {
- mgp.mu.Lock()
- g.JoinGroupLocked(addr2, test.dontInitialize)
- res := g.IsLocallyJoinedRLocked(addr2)
- mgp.mu.Unlock()
- if !res {
- t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2)
- }
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ // Joining groups should not send any reports.
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1)
+ res := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2)
+ res := g.IsLocallyJoinedRLocked(addr2)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
- // Receiving a query should not send any reports.
- mgp.mu.Lock()
- g.HandleQueryLocked(addr1, time.Nanosecond)
- mgp.mu.Unlock()
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(time.Nanosecond)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ // Receiving a query should not send any reports.
+ mgp.mu.Lock()
+ g.HandleQueryLocked(addr1, time.Nanosecond)
+ mgp.mu.Unlock()
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
- // Leaving groups should not send any leave messages.
- {
- mgp.mu.Lock()
- addr2LeaveRes := g.LeaveGroupLocked(addr2)
- addr1IsJoined := g.IsLocallyJoinedRLocked(addr1)
- addr2IsJoined := g.IsLocallyJoinedRLocked(addr2)
- mgp.mu.Unlock()
- if !addr2LeaveRes {
- t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2)
- }
- if !addr1IsJoined {
- t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
- }
- if addr2IsJoined {
- t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2)
- }
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ // Leaving groups should not send any leave messages.
+ {
+ mgp.mu.Lock()
+ addr2LeaveRes := g.LeaveGroupLocked(addr2)
+ addr1IsJoined := g.IsLocallyJoinedRLocked(addr1)
+ addr2IsJoined := g.IsLocallyJoinedRLocked(addr2)
+ mgp.mu.Unlock()
+ if !addr2LeaveRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2)
+ }
+ if !addr1IsJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ if addr2IsJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- })
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
@@ -742,7 +736,6 @@ func TestQueuedPackets(t *testing.T) {
mgp.init()
clock := faketime.NewManualClock()
g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
- Enabled: true,
Rand: rand.New(rand.NewSource(4)),
Clock: clock,
Protocol: &mgp,
@@ -753,7 +746,7 @@ func TestQueuedPackets(t *testing.T) {
// send the packet.
mgp.mu.Lock()
mgp.makeQueuePackets = true
- g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ g.JoinGroupLocked(addr1)
mgp.mu.Unlock()
if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
@@ -845,7 +838,7 @@ func TestQueuedPackets(t *testing.T) {
// not affect a newly joined group's reports from being sent.
mgp.mu.Lock()
mgp.makeQueuePackets = true
- g.JoinGroupLocked(addr2, false /* dontInitialize */)
+ g.JoinGroupLocked(addr2)
mgp.mu.Unlock()
if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index fb7a9e68e..da88d65d1 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -72,8 +72,6 @@ type igmpState struct {
// The IPv4 endpoint this igmpState is for.
ep *endpoint
- enabled bool
-
genericMulticastProtocol ip.GenericMulticastProtocolState
// igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
@@ -95,6 +93,13 @@ type igmpState struct {
igmpV1Job *tcpip.Job
}
+// Enabled implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) Enabled() bool {
+ // No need to perform IGMP on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled()
+}
+
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: igmp.ep.mu must be read locked.
@@ -127,11 +132,7 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
// Must only be called once for the lifetime of igmp.
func (igmp *igmpState) init(ep *endpoint) {
igmp.ep = ep
- // No need to perform IGMP on loopback interfaces since they don't have
- // neighbouring nodes.
- igmp.enabled = ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback()
igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
- Enabled: igmp.enabled,
Rand: ep.protocol.stack.Rand(),
Clock: ep.protocol.stack.Clock(),
Protocol: igmp,
@@ -223,7 +224,7 @@ func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxResp
// As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
// then change the state to note that an IGMPv1 router is present and
// schedule the query received Job.
- if igmp.enabled && maxRespTime == 0 {
+ if maxRespTime == 0 && igmp.Enabled() {
igmp.igmpV1Job.Cancel()
igmp.igmpV1Job.Schedule(v1RouterPresentTimeout)
igmp.setV1Present(true)
@@ -296,7 +297,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
//
// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
- igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress, !igmp.ep.Enabled() /* dontInitialize */)
+ igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress)
}
// isInGroup returns true if the specified group has been joined locally.
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index 6f64b8462..e8d1e7a79 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -58,6 +58,13 @@ type mldState struct {
genericMulticastProtocol ip.GenericMulticastProtocolState
}
+// Enabled implements ip.MulticastGroupProtocol.
+func (mld *mldState) Enabled() bool {
+ // No need to perform MLD on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled()
+}
+
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: mld.ep.mu must be read locked.
@@ -80,9 +87,6 @@ func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
func (mld *mldState) init(ep *endpoint) {
mld.ep = ep
mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
- // No need to perform MLD on loopback interfaces since they don't have
- // neighbouring nodes.
- Enabled: ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback(),
Rand: ep.protocol.stack.Rand(),
Clock: ep.protocol.stack.Clock(),
Protocol: mld,
@@ -112,7 +116,7 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
- mld.genericMulticastProtocol.JoinGroupLocked(groupAddress, !mld.ep.Enabled() /* dontInitialize */)
+ mld.genericMulticastProtocol.JoinGroupLocked(groupAddress)
}
// isInGroup returns true if the specified group has been joined locally.