diff options
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol.go | 49 | ||||
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol_test.go | 227 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld.go | 12 |
4 files changed, 154 insertions, 151 deletions
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index f85c5ff9d..f2f0e069c 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -131,17 +131,6 @@ type multicastGroupState struct { // GenericMulticastProtocolOptions holds options for the generic multicast // protocol. type GenericMulticastProtocolOptions struct { - // Enabled indicates whether the generic multicast protocol will be - // performed. - // - // When enabled, the protocol may transmit report and leave messages when - // joining and leaving multicast groups respectively, and handle incoming - // packets. - // - // When disabled, the protocol will still keep track of locally joined groups, - // it just won't transmit and handle packets, or update groups' state. - Enabled bool - // Rand is the source of random numbers. Rand *rand.Rand @@ -170,6 +159,17 @@ type GenericMulticastProtocolOptions struct { // MulticastGroupProtocol is a multicast group protocol whose core state machine // can be represented by GenericMulticastProtocolState. type MulticastGroupProtocol interface { + // Enabled indicates whether the generic multicast protocol will be + // performed. + // + // When enabled, the protocol may transmit report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // packets. + // + // When disabled, the protocol will still keep track of locally joined groups, + // it just won't transmit and handle packets, or update groups' state. + Enabled() bool + // SendReport sends a multicast report for the specified group address. // // Returns false if the caller should queue the report to be sent later. Note, @@ -196,6 +196,9 @@ type MulticastGroupProtocol interface { // // GenericMulticastProtocolState.Init MUST be called before calling any of // the methods on GenericMulticastProtocolState. +// +// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the +// multicast group protocol is disabled so that leave messages may be sent. type GenericMulticastProtocolState struct { // Do not allow overwriting this state. _ sync.NoCopy @@ -235,9 +238,11 @@ func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts Gene // // The groups will still be considered joined locally. // +// MUST be called when the multicast group protocol is disabled. +// // Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { - if !g.opts.Enabled { + if !g.opts.Protocol.Enabled() { return } @@ -255,7 +260,7 @@ func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { // // Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { - if !g.opts.Enabled { + if !g.opts.Protocol.Enabled() { return } @@ -290,12 +295,8 @@ func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { // JoinGroupLocked handles joining a new group. // -// If dontInitialize is true, the group will be not be initialized and will be -// left in the non-member state - no packets will be sent for it until it is -// initialized via InitializeGroups. -// // Precondition: g.protocolMU must be locked. -func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address, dontInitialize bool) { +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) { if info, ok := g.memberships[groupAddress]; ok { // The group has already been joined. info.joins++ @@ -310,6 +311,10 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre state: nonMember, lastToSendReport: false, delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + if !g.opts.Protocol.Enabled() { + panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress)) + } + info, ok := g.memberships[groupAddress] if !ok { panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) @@ -320,7 +325,7 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre }), } - if !dontInitialize && g.opts.Enabled { + if g.opts.Protocol.Enabled() { g.initializeNewMemberLocked(groupAddress, &info) } @@ -372,7 +377,7 @@ func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Addr // // Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { - if !g.opts.Enabled { + if !g.opts.Protocol.Enabled() { return } @@ -406,7 +411,7 @@ func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Add // // Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { - if !g.opts.Enabled { + if !g.opts.Protocol.Enabled() { return } @@ -518,7 +523,7 @@ func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddres // maybeSendLeave attempts to send a leave message. func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) { - if !g.opts.Enabled || !lastToSendReport { + if !g.opts.Protocol.Enabled() || !lastToSendReport { return } diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index 95040515c..f56f7aa90 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -50,6 +50,9 @@ type mockMulticastGroupProtocol struct { // Must only be accessed with mu held. makeQueuePackets bool + + // Must only be accessed with mu held. + disabled bool } func (m *mockMulticastGroupProtocol) init() { @@ -63,6 +66,22 @@ func (m *mockMulticastGroupProtocol) initLocked() { m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) } +func (m *mockMulticastGroupProtocol) setEnabled(v bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.disabled = !v +} + +// Enabled implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be read locked. +func (m *mockMulticastGroupProtocol) Enabled() bool { + return !m.disabled +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be locked. func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { if m.mu.TryLock() { m.mu.Unlock() @@ -77,6 +96,9 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo return !m.makeQueuePackets, nil } +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be locked. func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { if m.mu.TryLock() { m.mu.Unlock() @@ -115,7 +137,11 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t cmp.FilterPath( func(p cmp.Path) bool { - return p.Last().String() == ".mu" || p.Last().String() == ".t" || p.Last().String() == ".makeQueuePackets" + switch p.Last().String() { + case ".mu", ".t", ".makeQueuePackets", ".disabled": + return true + } + return false }, cmp.Ignore(), ), @@ -150,7 +176,6 @@ func TestJoinGroup(t *testing.T) { mgp.init() g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ - Enabled: true, Rand: rand.New(rand.NewSource(0)), Clock: clock, Protocol: &mgp, @@ -161,7 +186,7 @@ func TestJoinGroup(t *testing.T) { // Joining a group should send a report immediately and another after // a random interval between 0 and the maximum unsolicited report delay. mgp.mu.Lock() - g.JoinGroupLocked(test.addr, false /* dontInitialize */) + g.JoinGroupLocked(test.addr) mgp.mu.Unlock() if test.shouldSendReports { if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { @@ -210,7 +235,6 @@ func TestLeaveGroup(t *testing.T) { mgp.init() g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ - Enabled: true, Rand: rand.New(rand.NewSource(1)), Clock: clock, Protocol: &mgp, @@ -219,7 +243,7 @@ func TestLeaveGroup(t *testing.T) { }) mgp.mu.Lock() - g.JoinGroupLocked(test.addr, false /* dontInitialize */) + g.JoinGroupLocked(test.addr) mgp.mu.Unlock() if test.shouldSendMessages { if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { @@ -295,7 +319,6 @@ func TestHandleReport(t *testing.T) { mgp.init() g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ - Enabled: true, Rand: rand.New(rand.NewSource(2)), Clock: clock, Protocol: &mgp, @@ -304,19 +327,19 @@ func TestHandleReport(t *testing.T) { }) mgp.mu.Lock() - g.JoinGroupLocked(addr1, false /* dontInitialize */) + g.JoinGroupLocked(addr1) mgp.mu.Unlock() if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } mgp.mu.Lock() - g.JoinGroupLocked(addr2, false /* dontInitialize */) + g.JoinGroupLocked(addr2) mgp.mu.Unlock() if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } mgp.mu.Lock() - g.JoinGroupLocked(addr3, false /* dontInitialize */) + g.JoinGroupLocked(addr3) mgp.mu.Unlock() if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -391,7 +414,6 @@ func TestHandleQuery(t *testing.T) { mgp.init() g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ - Enabled: true, Rand: rand.New(rand.NewSource(3)), Clock: clock, Protocol: &mgp, @@ -400,19 +422,19 @@ func TestHandleQuery(t *testing.T) { }) mgp.mu.Lock() - g.JoinGroupLocked(addr1, false /* dontInitialize */) + g.JoinGroupLocked(addr1) mgp.mu.Unlock() if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } mgp.mu.Lock() - g.JoinGroupLocked(addr2, false /* dontInitialize */) + g.JoinGroupLocked(addr2) mgp.mu.Unlock() if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } mgp.mu.Lock() - g.JoinGroupLocked(addr3, false /* dontInitialize */) + g.JoinGroupLocked(addr3) mgp.mu.Unlock() if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -451,7 +473,6 @@ func TestJoinCount(t *testing.T) { mgp.init() g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ - Enabled: true, Rand: rand.New(rand.NewSource(4)), Clock: clock, Protocol: &mgp, @@ -461,7 +482,7 @@ func TestJoinCount(t *testing.T) { // Set the join count to 2 for a group. { mgp.mu.Lock() - g.JoinGroupLocked(addr1, false /* dontInitialize */) + g.JoinGroupLocked(addr1) res := g.IsLocallyJoinedRLocked(addr1) mgp.mu.Unlock() if !res { @@ -474,7 +495,7 @@ func TestJoinCount(t *testing.T) { } { mgp.mu.Lock() - g.JoinGroupLocked(addr1, false /* dontInitialize */) + g.JoinGroupLocked(addr1) res := g.IsLocallyJoinedRLocked(addr1) mgp.mu.Unlock() if !res { @@ -563,7 +584,6 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) { mgp.init() g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ - Enabled: true, Rand: rand.New(rand.NewSource(3)), Clock: clock, Protocol: &mgp, @@ -572,19 +592,19 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) { }) mgp.mu.Lock() - g.JoinGroupLocked(addr1, false /* dontInitialize */) + g.JoinGroupLocked(addr1) mgp.mu.Unlock() if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } mgp.mu.Lock() - g.JoinGroupLocked(addr2, false /* dontInitialize */) + g.JoinGroupLocked(addr2) mgp.mu.Unlock() if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } mgp.mu.Lock() - g.JoinGroupLocked(addr3, false /* dontInitialize */) + g.JoinGroupLocked(addr3) mgp.mu.Unlock() if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -634,105 +654,79 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) { // TestGroupStateNonMember tests that groups do not send packets when in the // non-member state, but are still considered locally joined. func TestGroupStateNonMember(t *testing.T) { - tests := []struct { - name string - enabled bool - dontInitialize bool - }{ - { - name: "Disabled", - enabled: false, - dontInitialize: false, - }, - { - name: "Keep non-member", - enabled: true, - dontInitialize: true, - }, - { - name: "disabled and Keep non-member", - enabled: false, - dontInitialize: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - mgp := mockMulticastGroupProtocol{t: t} - clock := faketime.NewManualClock() + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() - mgp.init() - g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ - Enabled: test.enabled, - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - Protocol: &mgp, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) + mgp.init() + mgp.setEnabled(false) + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) - // Joining groups should not send any reports. - { - mgp.mu.Lock() - g.JoinGroupLocked(addr1, test.dontInitialize) - res := g.IsLocallyJoinedRLocked(addr1) - mgp.mu.Unlock() - if !res { - t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) - } - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - { - mgp.mu.Lock() - g.JoinGroupLocked(addr2, test.dontInitialize) - res := g.IsLocallyJoinedRLocked(addr2) - mgp.mu.Unlock() - if !res { - t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2) - } - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + // Joining groups should not send any reports. + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + { + mgp.mu.Lock() + g.JoinGroupLocked(addr2) + res := g.IsLocallyJoinedRLocked(addr2) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } - // Receiving a query should not send any reports. - mgp.mu.Lock() - g.HandleQueryLocked(addr1, time.Nanosecond) - mgp.mu.Unlock() - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(time.Nanosecond) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + // Receiving a query should not send any reports. + mgp.mu.Lock() + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Nanosecond) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } - // Leaving groups should not send any leave messages. - { - mgp.mu.Lock() - addr2LeaveRes := g.LeaveGroupLocked(addr2) - addr1IsJoined := g.IsLocallyJoinedRLocked(addr1) - addr2IsJoined := g.IsLocallyJoinedRLocked(addr2) - mgp.mu.Unlock() - if !addr2LeaveRes { - t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2) - } - if !addr1IsJoined { - t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) - } - if addr2IsJoined { - t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2) - } - } - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + // Leaving groups should not send any leave messages. + { + mgp.mu.Lock() + addr2LeaveRes := g.LeaveGroupLocked(addr2) + addr1IsJoined := g.IsLocallyJoinedRLocked(addr1) + addr2IsJoined := g.IsLocallyJoinedRLocked(addr2) + mgp.mu.Unlock() + if !addr2LeaveRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2) + } + if !addr1IsJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + if addr2IsJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } - clock.Advance(time.Hour) - if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } @@ -742,7 +736,6 @@ func TestQueuedPackets(t *testing.T) { mgp.init() clock := faketime.NewManualClock() g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ - Enabled: true, Rand: rand.New(rand.NewSource(4)), Clock: clock, Protocol: &mgp, @@ -753,7 +746,7 @@ func TestQueuedPackets(t *testing.T) { // send the packet. mgp.mu.Lock() mgp.makeQueuePackets = true - g.JoinGroupLocked(addr1, false /* dontInitialize */) + g.JoinGroupLocked(addr1) mgp.mu.Unlock() if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) @@ -845,7 +838,7 @@ func TestQueuedPackets(t *testing.T) { // not affect a newly joined group's reports from being sent. mgp.mu.Lock() mgp.makeQueuePackets = true - g.JoinGroupLocked(addr2, false /* dontInitialize */) + g.JoinGroupLocked(addr2) mgp.mu.Unlock() if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index fb7a9e68e..da88d65d1 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -72,8 +72,6 @@ type igmpState struct { // The IPv4 endpoint this igmpState is for. ep *endpoint - enabled bool - genericMulticastProtocol ip.GenericMulticastProtocolState // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from @@ -95,6 +93,13 @@ type igmpState struct { igmpV1Job *tcpip.Job } +// Enabled implements ip.MulticastGroupProtocol. +func (igmp *igmpState) Enabled() bool { + // No need to perform IGMP on loopback interfaces since they don't have + // neighbouring nodes. + return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled() +} + // SendReport implements ip.MulticastGroupProtocol. // // Precondition: igmp.ep.mu must be read locked. @@ -127,11 +132,7 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // Must only be called once for the lifetime of igmp. func (igmp *igmpState) init(ep *endpoint) { igmp.ep = ep - // No need to perform IGMP on loopback interfaces since they don't have - // neighbouring nodes. - igmp.enabled = ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ - Enabled: igmp.enabled, Rand: ep.protocol.stack.Rand(), Clock: ep.protocol.stack.Clock(), Protocol: igmp, @@ -223,7 +224,7 @@ func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxResp // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero // then change the state to note that an IGMPv1 router is present and // schedule the query received Job. - if igmp.enabled && maxRespTime == 0 { + if maxRespTime == 0 && igmp.Enabled() { igmp.igmpV1Job.Cancel() igmp.igmpV1Job.Schedule(v1RouterPresentTimeout) igmp.setV1Present(true) @@ -296,7 +297,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // // Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { - igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress, !igmp.ep.Enabled() /* dontInitialize */) + igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress) } // isInGroup returns true if the specified group has been joined locally. diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 6f64b8462..e8d1e7a79 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -58,6 +58,13 @@ type mldState struct { genericMulticastProtocol ip.GenericMulticastProtocolState } +// Enabled implements ip.MulticastGroupProtocol. +func (mld *mldState) Enabled() bool { + // No need to perform MLD on loopback interfaces since they don't have + // neighbouring nodes. + return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled() +} + // SendReport implements ip.MulticastGroupProtocol. // // Precondition: mld.ep.mu must be read locked. @@ -80,9 +87,6 @@ func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { func (mld *mldState) init(ep *endpoint) { mld.ep = ep mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ - // No need to perform MLD on loopback interfaces since they don't have - // neighbouring nodes. - Enabled: ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback(), Rand: ep.protocol.stack.Rand(), Clock: ep.protocol.stack.Clock(), Protocol: mld, @@ -112,7 +116,7 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { // // Precondition: mld.ep.mu must be locked. func (mld *mldState) joinGroup(groupAddress tcpip.Address) { - mld.genericMulticastProtocol.JoinGroupLocked(groupAddress, !mld.ep.Enabled() /* dontInitialize */) + mld.genericMulticastProtocol.JoinGroupLocked(groupAddress) } // isInGroup returns true if the specified group has been joined locally. |