summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-12-01 07:50:49 -0800
committergVisor bot <gvisor-bot@google.com>2020-12-01 07:52:40 -0800
commit25570ac4f3a0c4b51251b8111d64d5c1dc6b9a67 (patch)
treeeaa642fa6324fe7dfbb1c6a417170a88d16789e7 /pkg
parent6b1dbbbdc8ec0b5d04e0961a216bd7323dbc45fb (diff)
Track join count in multicast group protocol state
Before this change, the join count and the state for IGMP/MLD was held across different types which required multiple locks to be held when accessing a multicast group's state. Bug #4682, #4861 Fixes #4916 PiperOrigin-RevId: 345019091
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go398
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go456
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go82
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go69
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go76
-rw-r--r--pkg/tcpip/network/ipv6/mld.go75
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go61
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go50
-rw-r--r--pkg/tcpip/stack/nic.go9
-rw-r--r--pkg/tcpip/stack/registration.go8
10 files changed, 791 insertions, 493 deletions
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
index 3113f4bbe..f14e2a88a 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -31,12 +31,15 @@ type hostState int
// (RFC 2710 section 5). Even though the states are generic across both IGMPv2
// and MLDv1, IGMPv2 terminology will be used.
const (
- // "'Non-Member' state, when the host does not belong to the group on
- // the interface. This is the initial state for all memberships on
+ // nonMember is the "'Non-Member' state, when the host does not belong to the
+ // group on the interface. This is the initial state for all memberships on
// all network interfaces; it requires no storage in the host."
//
// 'Non-Listener' is the MLDv1 term used to describe this state.
- _ hostState = iota
+ //
+ // This state is used to keep track of groups that have been joined locally,
+ // but without advertising the membership to the network.
+ nonMember hostState = iota
// delayingMember is the "'Delaying Member' state, when the host belongs to
// the group on the interface and has a report delay timer running for that
@@ -56,7 +59,10 @@ const (
// multicastGroupState holds the Generic Multicast Protocol state for a
// multicast group.
type multicastGroupState struct {
- // state contains the host's state for the group.
+ // joins is the number of times the group has been joined.
+ joins uint64
+
+ // state holds the host's state for the group.
state hostState
// lastToSendReport is true if we sent the last report for the group. It is
@@ -75,6 +81,45 @@ type multicastGroupState struct {
delayedReportJob *tcpip.Job
}
+// GenericMulticastProtocolOptions holds options for the generic multicast
+// protocol.
+type GenericMulticastProtocolOptions struct {
+ // Enabled indicates whether the generic multicast protocol will be
+ // performed.
+ //
+ // When enabled, the protocol may transmit report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // packets.
+ //
+ // When disabled, the protocol will still keep track of locally joined groups,
+ // it just won't transmit and handle packets, or update groups' state.
+ Enabled bool
+
+ // Rand is the source of random numbers.
+ Rand *rand.Rand
+
+ // Clock is the clock used to create timers.
+ Clock tcpip.Clock
+
+ // Protocol is the implementation of the variant of multicast group protocol
+ // in use.
+ Protocol MulticastGroupProtocol
+
+ // MaxUnsolicitedReportDelay is the maximum amount of time to wait between
+ // transmitting unsolicited reports.
+ //
+ // Unsolicited reports are transmitted when a group is newly joined.
+ MaxUnsolicitedReportDelay time.Duration
+
+ // AllNodesAddress is a multicast address that all nodes on a network should
+ // be a member of.
+ //
+ // This address will not have the generic multicast protocol performed on it;
+ // it will be left in the non member/listener state, and packets will never
+ // be sent for it.
+ AllNodesAddress tcpip.Address
+}
+
// MulticastGroupProtocol is a multicast group protocol whose core state machine
// can be represented by GenericMulticastProtocolState.
type MulticastGroupProtocol interface {
@@ -96,13 +141,10 @@ type MulticastGroupProtocol interface {
// GenericMulticastProtocolState.Init MUST be called before calling any of
// the methods on GenericMulticastProtocolState.
type GenericMulticastProtocolState struct {
- rand *rand.Rand
- clock tcpip.Clock
- protocol MulticastGroupProtocol
- maxUnsolicitedReportDelay time.Duration
+ opts GenericMulticastProtocolOptions
mu struct {
- sync.Mutex
+ sync.RWMutex
// memberships holds group addresses and their associated state.
memberships map[tcpip.Address]multicastGroupState
@@ -113,130 +155,123 @@ type GenericMulticastProtocolState struct {
//
// maxUnsolicitedReportDelay is the maximum time between sending unsolicited
// reports after joining a group.
-func (g *GenericMulticastProtocolState) Init(rand *rand.Rand, clock tcpip.Clock, protocol MulticastGroupProtocol, maxUnsolicitedReportDelay time.Duration) {
+func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) {
g.mu.Lock()
defer g.mu.Unlock()
- g.rand = rand
- g.clock = clock
- g.protocol = protocol
- g.maxUnsolicitedReportDelay = maxUnsolicitedReportDelay
+ g.opts = opts
g.mu.memberships = make(map[tcpip.Address]multicastGroupState)
}
+// MakeAllNonMember transitions all groups to the non-member state.
+//
+// The groups will still be considered joined locally.
+func (g *GenericMulticastProtocolState) MakeAllNonMember() {
+ if !g.opts.Enabled {
+ return
+ }
+
+ g.mu.Lock()
+ defer g.mu.Unlock()
+
+ for groupAddress, info := range g.mu.memberships {
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ g.mu.memberships[groupAddress] = info
+ }
+}
+
+// InitializeGroups initializes each group, as if they were newly joined but
+// without affecting the groups' join count.
+//
+// Must only be called after calling MakeAllNonMember as a group should not be
+// initialized while it is not in the non-member state.
+func (g *GenericMulticastProtocolState) InitializeGroups() {
+ if !g.opts.Enabled {
+ return
+ }
+
+ g.mu.Lock()
+ defer g.mu.Unlock()
+
+ for groupAddress, info := range g.mu.memberships {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ g.mu.memberships[groupAddress] = info
+ }
+}
+
// JoinGroup handles joining a new group.
//
-// Returns false if the group has already been joined.
-func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address) bool {
+// If dontInitialize is true, the group will be not be initialized and will be
+// left in the non-member state - no packets will be sent for it until it is
+// initialized via InitializeGroups.
+func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) {
g.mu.Lock()
defer g.mu.Unlock()
- if _, ok := g.mu.memberships[groupAddress]; ok {
+ if info, ok := g.mu.memberships[groupAddress]; ok {
// The group has already been joined.
- return false
+ info.joins++
+ g.mu.memberships[groupAddress] = info
+ return
}
info := multicastGroupState{
- // There isn't a job scheduled currently, so it's just idle.
- state: idleMember,
- // Joining a group immediately sends a report.
- lastToSendReport: true,
- delayedReportJob: tcpip.NewJob(g.clock, &g.mu, func() {
+ // Since we just joined the group, its count is 1.
+ joins: 1,
+ // The state will be updated below, if required.
+ state: nonMember,
+ lastToSendReport: false,
+ delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() {
info, ok := g.mu.memberships[groupAddress]
if !ok {
panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
}
- info.lastToSendReport = g.protocol.SendReport(groupAddress) == nil
+ info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil
info.state = idleMember
g.mu.memberships[groupAddress] = info
}),
}
- // As per RFC 2236 section 3 page 5 (for IGMPv2),
- //
- // When a host joins a multicast group, it should immediately transmit an
- // unsolicited Version 2 Membership Report for that group" ... "it is
- // recommended that it be repeated".
- //
- // As per RFC 2710 section 4 page 6 (for MLDv1),
- //
- // When a node starts listening to a multicast address on an interface,
- // it should immediately transmit an unsolicited Report for that address
- // on that interface, in case it is the first listener on the link. To
- // cover the possibility of the initial Report being lost or damaged, it
- // is recommended that it be repeated once or twice after short delays
- // [Unsolicited Report Interval].
- //
- // TODO(gvisor.dev/issue/4901): Support a configurable number of initial
- // unsolicited reports.
- info.lastToSendReport = g.protocol.SendReport(groupAddress) == nil
- g.setDelayTimerForAddressRLocked(groupAddress, &info, g.maxUnsolicitedReportDelay)
+ if !dontInitialize && g.opts.Enabled {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ }
+
g.mu.memberships[groupAddress] = info
- return true
+}
+
+// IsLocallyJoined returns true if the group is locally joined.
+func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool {
+ g.mu.RLock()
+ defer g.mu.RUnlock()
+ _, ok := g.mu.memberships[groupAddress]
+ return ok
}
// LeaveGroup handles leaving the group.
-func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) {
+//
+// Returns false if the group is not currently joined.
+func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool {
g.mu.Lock()
defer g.mu.Unlock()
info, ok := g.mu.memberships[groupAddress]
if !ok {
- return
+ return false
}
- info.delayedReportJob.Cancel()
- delete(g.mu.memberships, groupAddress)
- if info.lastToSendReport {
- // Okay to ignore the error here as if packet write failed, the multicast
- // routers will eventually drop our membership anyways. If the interface is
- // being disabled or removed, the generic multicast protocol's should be
- // cleared eventually.
- //
- // As per RFC 2236 section 3 page 5 (for IGMPv2),
- //
- // When a router receives a Report, it adds the group being reported to
- // the list of multicast group memberships on the network on which it
- // received the Report and sets the timer for the membership to the
- // [Group Membership Interval]. Repeated Reports refresh the timer. If
- // no Reports are received for a particular group before this timer has
- // expired, the router assumes that the group has no local members and
- // that it need not forward remotely-originated multicasts for that
- // group onto the attached network.
- //
- // As per RFC 2710 section 4 page 5 (for MLDv1),
- //
- // When a router receives a Report from a link, if the reported address
- // is not already present in the router's list of multicast address
- // having listeners on that link, the reported address is added to the
- // list, its timer is set to [Multicast Listener Interval], and its
- // appearance is made known to the router's multicast routing component.
- // If a Report is received for a multicast address that is already
- // present in the router's list, the timer for that address is reset to
- // [Multicast Listener Interval]. If an address's timer expires, it is
- // assumed that there are no longer any listeners for that address
- // present on the link, so it is deleted from the list and its
- // disappearance is made known to the multicast routing component.
- //
- // The requirement to send a leave message is also optional (it MAY be
- // skipped):
- //
- // As per RFC 2236 section 6 page 8 (for IGMPv2),
- //
- // "send leave" for the group on the interface. If the interface
- // state says the Querier is running IGMPv1, this action SHOULD be
- // skipped. If the flag saying we were the last host to report is
- // cleared, this action MAY be skipped. The Leave Message is sent to
- // the ALL-ROUTERS group (224.0.0.2).
- //
- // As per RFC 2710 section 5 page 8 (for MLDv1),
- //
- // "send done" for the address on the interface. If the flag saying
- // we were the last node to report is cleared, this action MAY be
- // skipped. The Done message is sent to the link-scope all-routers
- // address (FF02::2).
- _ = g.protocol.SendLeave(groupAddress)
+ if info.joins == 0 {
+ panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress))
+ }
+ info.joins--
+ if info.joins != 0 {
+ // If we still have outstanding joins, then do nothing further.
+ g.mu.memberships[groupAddress] = info
+ return true
}
+
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ delete(g.mu.memberships, groupAddress)
+ return true
}
// HandleQuery handles a query message with the specified maximum response time.
@@ -247,6 +282,10 @@ func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) {
// Report(s) will be scheduled to be sent after a random duration between 0 and
// the maximum response time.
func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) {
+ if !g.opts.Enabled {
+ return
+ }
+
g.mu.Lock()
defer g.mu.Unlock()
@@ -278,6 +317,10 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address,
// If the report is for a joined group, any active delayed report will be
// cancelled and the host state for the group transitions to idle.
func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) {
+ if !g.opts.Enabled {
+ return
+ }
+
g.mu.Lock()
defer g.mu.Unlock()
@@ -293,7 +336,7 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address)
// multicast address while it has a timer running for that same address
// on that interface, it stops its timer and does not send a Report for
// that address, thus suppressing duplicate reports on the link.
- if info, ok := g.mu.memberships[groupAddress]; ok {
+ if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember {
info.delayedReportJob.Cancel()
info.lastToSendReport = false
info.state = idleMember
@@ -301,10 +344,167 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address)
}
}
+// initializeNewMemberLocked initializes a new group membership.
+//
+// Precondition: g.mu must be locked.
+func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != nonMember {
+ panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state))
+ }
+
+ info.state = idleMember
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a host joins a multicast group, it should immediately transmit an
+ // unsolicited Version 2 Membership Report for that group" ... "it is
+ // recommended that it be repeated".
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node starts listening to a multicast address on an interface,
+ // it should immediately transmit an unsolicited Report for that address
+ // on that interface, in case it is the first listener on the link. To
+ // cover the possibility of the initial Report being lost or damaged, it
+ // is recommended that it be repeated once or twice after short delays
+ // [Unsolicited Report Interval].
+ //
+ // TODO(gvisor.dev/issue/4901): Support a configurable number of initial
+ // unsolicited reports.
+ info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil
+ g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+}
+
+// maybeSendLeave attempts to send a leave message.
+func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
+ if !g.opts.Enabled || !lastToSendReport {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // Okay to ignore the error here as if packet write failed, the multicast
+ // routers will eventually drop our membership anyways. If the interface is
+ // being disabled or removed, the generic multicast protocol's should be
+ // cleared eventually.
+ //
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a router receives a Report, it adds the group being reported to
+ // the list of multicast group memberships on the network on which it
+ // received the Report and sets the timer for the membership to the
+ // [Group Membership Interval]. Repeated Reports refresh the timer. If
+ // no Reports are received for a particular group before this timer has
+ // expired, the router assumes that the group has no local members and
+ // that it need not forward remotely-originated multicasts for that
+ // group onto the attached network.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // When a router receives a Report from a link, if the reported address
+ // is not already present in the router's list of multicast address
+ // having listeners on that link, the reported address is added to the
+ // list, its timer is set to [Multicast Listener Interval], and its
+ // appearance is made known to the router's multicast routing component.
+ // If a Report is received for a multicast address that is already
+ // present in the router's list, the timer for that address is reset to
+ // [Multicast Listener Interval]. If an address's timer expires, it is
+ // assumed that there are no longer any listeners for that address
+ // present on the link, so it is deleted from the list and its
+ // disappearance is made known to the multicast routing component.
+ //
+ // The requirement to send a leave message is also optional (it MAY be
+ // skipped):
+ //
+ // As per RFC 2236 section 6 page 8 (for IGMPv2),
+ //
+ // "send leave" for the group on the interface. If the interface
+ // state says the Querier is running IGMPv1, this action SHOULD be
+ // skipped. If the flag saying we were the last host to report is
+ // cleared, this action MAY be skipped. The Leave Message is sent to
+ // the ALL-ROUTERS group (224.0.0.2).
+ //
+ // As per RFC 2710 section 5 page 8 (for MLDv1),
+ //
+ // "send done" for the address on the interface. If the flag saying
+ // we were the last node to report is cleared, this action MAY be
+ // skipped. The Done message is sent to the link-scope all-routers
+ // address (FF02::2).
+ _ = g.opts.Protocol.SendLeave(groupAddress)
+}
+
+// transitionToNonMemberLocked transitions the given multicast group the the
+// non-member/listener state.
+//
+// Precondition: e.mu must be locked.
+func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state == nonMember {
+ return
+ }
+
+ info.delayedReportJob.Cancel()
+ g.maybeSendLeave(groupAddress, info.lastToSendReport)
+ info.lastToSendReport = false
+ info.state = nonMember
+}
+
// setDelayTimerForAddressRLocked sets timer to send a delay report.
//
// Precondition: g.mu MUST be read locked.
func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) {
+ if info.state == nonMember {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
// As per RFC 2236 section 3 page 3 (for IGMPv2),
//
// If a timer for the group is already unning, it is reset to the random
@@ -342,5 +542,5 @@ func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime
if maxRespTime == 0 {
return 0
}
- return time.Duration(g.rand.Int63n(int64(maxRespTime)))
+ return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime)))
}
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
index eb48c0d51..670be30d4 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -29,18 +29,21 @@ const (
addr1 = tcpip.Address("\x01")
addr2 = tcpip.Address("\x02")
addr3 = tcpip.Address("\x03")
+ addr4 = tcpip.Address("\x04")
+
+ maxUnsolicitedReportDelay = time.Second
)
var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
type mockMulticastGroupProtocol struct {
sendReportGroupAddrCount map[tcpip.Address]int
- sendLeaveGroupAddr tcpip.Address
+ sendLeaveGroupAddrCount map[tcpip.Address]int
}
func (m *mockMulticastGroupProtocol) init() {
m.sendReportGroupAddrCount = make(map[tcpip.Address]int)
- m.sendLeaveGroupAddr = ""
+ m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
}
func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error {
@@ -49,87 +52,146 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcp
}
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
- m.sendLeaveGroupAddr = groupAddress
+ m.sendLeaveGroupAddrCount[groupAddress]++
return nil
}
-func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddr tcpip.Address) string {
+func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
sendReportGroupAddressesMap := make(map[tcpip.Address]int)
for _, a := range sendReportGroupAddresses {
sendReportGroupAddressesMap[a] = 1
}
+ sendLeaveGroupAddressesMap := make(map[tcpip.Address]int)
+ for _, a := range sendLeaveGroupAddresses {
+ sendLeaveGroupAddressesMap[a] = 1
+ }
+
diff := cmp.Diff(mockMulticastGroupProtocol{
sendReportGroupAddrCount: sendReportGroupAddressesMap,
- sendLeaveGroupAddr: sendLeaveGroupAddr,
+ sendLeaveGroupAddrCount: sendLeaveGroupAddressesMap,
}, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{}))
mgp.init()
return diff
}
func TestJoinGroup(t *testing.T) {
- const maxUnsolicitedReportDelay = time.Second
-
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
- clock := faketime.NewManualClock()
- g.Init(rand.New(rand.NewSource(0)), clock, &mgp, maxUnsolicitedReportDelay)
-
- // Joining a group should send a report immediately and another after
- // a random interval between 0 and the maximum unsolicited report delay.
- if !g.JoinGroup(addr1) {
- t.Errorf("got g.JoinGroup(%s) = false, want = true", addr1)
- }
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendReports bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendReports: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendReports: false,
+ },
}
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(0)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ // Joining a group should send a report immediately and another after
+ // a random interval between 0 and the maximum unsolicited report delay.
+ g.JoinGroup(test.addr, false /* dontInitialize */)
+ if test.shouldSendReports {
+ if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
func TestLeaveGroup(t *testing.T) {
- const maxUnsolicitedReportDelay = time.Second
-
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
- clock := faketime.NewManualClock()
- g.Init(rand.New(rand.NewSource(1)), clock, &mgp, maxUnsolicitedReportDelay)
-
- if !g.JoinGroup(addr1) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1)
- }
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendMessages bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendMessages: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendMessages: false,
+ },
}
- // Leaving a group should send a leave report immediately and cancel any
- // delayed reports.
- g.LeaveGroup(addr1)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, addr1 /* sendLeaveGroupAddr */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(1)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ g.JoinGroup(test.addr, false /* dontInitialize */)
+ if test.shouldSendMessages {
+ if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ // Leaving a group should send a leave report immediately and cancel any
+ // delayed reports.
+ if !g.LeaveGroup(test.addr) {
+ t.Fatalf("got g.LeaveGroup(%s) = false, want = true", test.addr)
+ }
+ if test.shouldSendMessages {
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
func TestHandleReport(t *testing.T) {
- const maxUnsolicitedReportDelay = time.Second
-
tests := []struct {
name string
reportAddr tcpip.Address
@@ -151,10 +213,15 @@ func TestHandleReport(t *testing.T) {
expectReportsFor: []tcpip.Address{addr2},
},
{
- name: "Specified other",
+ name: "Specified all-nodes",
reportAddr: addr3,
expectReportsFor: []tcpip.Address{addr1, addr2},
},
+ {
+ name: "Specified other",
+ reportAddr: addr4,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
}
for _, test := range tests {
@@ -163,18 +230,25 @@ func TestHandleReport(t *testing.T) {
var mgp mockMulticastGroupProtocol
mgp.init()
clock := faketime.NewManualClock()
- g.Init(rand.New(rand.NewSource(2)), clock, &mgp, maxUnsolicitedReportDelay)
-
- if !g.JoinGroup(addr1) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1)
- }
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(2)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ g.JoinGroup(addr1, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if !g.JoinGroup(addr2) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr2)
+ g.JoinGroup(addr2, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ g.JoinGroup(addr3, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
@@ -183,14 +257,14 @@ func TestHandleReport(t *testing.T) {
g.HandleReport(test.reportAddr)
if len(test.expectReportsFor) != 0 {
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
@@ -198,8 +272,6 @@ func TestHandleReport(t *testing.T) {
}
func TestHandleQuery(t *testing.T) {
- const maxUnsolicitedReportDelay = time.Second
-
tests := []struct {
name string
queryAddr tcpip.Address
@@ -225,11 +297,17 @@ func TestHandleQuery(t *testing.T) {
expectReportsFor: []tcpip.Address{addr1},
},
{
- name: "Specified other",
+ name: "Specified all-nodes",
queryAddr: addr3,
maxDelay: 3,
expectReportsFor: nil,
},
+ {
+ name: "Specified other",
+ queryAddr: addr4,
+ maxDelay: 4,
+ expectReportsFor: nil,
+ },
}
for _, test := range tests {
@@ -238,22 +316,29 @@ func TestHandleQuery(t *testing.T) {
var mgp mockMulticastGroupProtocol
mgp.init()
clock := faketime.NewManualClock()
- g.Init(rand.New(rand.NewSource(3)), clock, &mgp, maxUnsolicitedReportDelay)
-
- if !g.JoinGroup(addr1) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1)
- }
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ g.JoinGroup(addr1, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if !g.JoinGroup(addr2) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr2)
+ g.JoinGroup(addr2, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ g.JoinGroup(addr3, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
@@ -262,33 +347,230 @@ func TestHandleQuery(t *testing.T) {
g.HandleQuery(test.queryAddr, test.maxDelay)
if len(test.expectReportsFor) != 0 {
clock.Advance(test.maxDelay)
- if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" {
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
}
}
-func TestDoubleJoinGroup(t *testing.T) {
+func TestJoinCount(t *testing.T) {
var g ip.GenericMulticastProtocolState
var mgp mockMulticastGroupProtocol
mgp.init()
clock := faketime.NewManualClock()
- g.Init(rand.New(rand.NewSource(4)), clock, &mgp, time.Second)
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: time.Second,
+ })
+
+ // Set the join count to 2 for a group.
+ g.JoinGroup(addr1, false /* dontInitialize */)
+ if !g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ }
+ // Only the first join should trigger a report to be sent.
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ g.JoinGroup(addr1, false /* dontInitialize */)
+ if !g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Group should still be considered joined after leaving once.
+ if !g.LeaveGroup(addr1) {
+ t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1)
+ }
+ if !g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ }
+ // A leave report should only be sent once the join count reaches 0.
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Leaving once more should actually remove us from the group.
+ if !g.LeaveGroup(addr1) {
+ t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1)
+ }
+ if g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Group should no longer be joined so we should not have anything to
+ // leave.
+ if g.LeaveGroup(addr1) {
+ t.Fatalf("got g.LeaveGroup(%s) = true, want = false", addr1)
+ }
+ if g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestMakeAllNonMemberAndInitialize(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ g.JoinGroup(addr1, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ g.JoinGroup(addr2, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ g.JoinGroup(addr3, false /* dontInitialize */)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should send the leave reports for each but still consider them locally
+ // joined.
+ g.MakeAllNonMember()
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ for _, group := range []tcpip.Address{addr1, addr2, addr3} {
+ if !g.IsLocallyJoined(group) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", group)
+ }
+ }
- if !g.JoinGroup(addr1) {
- t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1)
+ // Should send the initial set of unsolcited reports.
+ g.InitializeGroups()
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+// TestGroupStateNonMember tests that groups do not send packets when in the
+// non-member state, but are still considered locally joined.
+func TestGroupStateNonMember(t *testing.T) {
+ tests := []struct {
+ name string
+ enabled bool
+ dontInitialize bool
+ }{
+ {
+ name: "Disabled",
+ enabled: false,
+ dontInitialize: false,
+ },
+ {
+ name: "Keep non-member",
+ enabled: true,
+ dontInitialize: true,
+ },
+ {
+ name: "disabled and Keep non-member",
+ enabled: false,
+ dontInitialize: true,
+ },
}
- // Joining the same group twice should fail.
- if g.JoinGroup(addr1) {
- t.Errorf("got g.JoinGroup(%s) = true, want = false", addr1)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: test.enabled,
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+
+ g.JoinGroup(addr1, test.dontInitialize)
+ if !g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ g.JoinGroup(addr2, test.dontInitialize)
+ if !g.IsLocallyJoined(addr2) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr2)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ g.HandleQuery(addr1, time.Nanosecond)
+ clock.Advance(time.Nanosecond)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ if !g.LeaveGroup(addr2) {
+ t.Errorf("got g.LeaveGroup(%s) = false, want = true", addr2)
+ }
+ if !g.IsLocallyJoined(addr1) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if g.IsLocallyJoined(addr2) {
+ t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr2)
+ }
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ clock.Advance(time.Hour)
+ if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index 37f1822ca..18ccd28c3 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -124,7 +124,14 @@ func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) {
defer igmp.mu.Unlock()
igmp.ep = ep
igmp.opts = opts
- igmp.mu.genericMulticastProtocol.Init(ep.protocol.stack.Rand(), ep.protocol.stack.Clock(), igmp, UnsolicitedReportIntervalMax)
+ igmp.mu.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: opts.Enabled,
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: igmp,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv4AllSystems,
+ })
igmp.igmpV1Present = igmpV1PresentDefault
igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() {
igmp.setV1Present(false)
@@ -201,17 +208,13 @@ func (igmp *igmpState) setV1Present(v bool) {
}
func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) {
- if !igmp.opts.Enabled {
- return
- }
-
igmp.mu.Lock()
defer igmp.mu.Unlock()
// As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
// then change the state to note that an IGMPv1 router is present and
// schedule the query received Job.
- if maxRespTime == 0 {
+ if maxRespTime == 0 && igmp.opts.Enabled {
igmp.mu.igmpV1Job.Cancel()
igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout)
igmp.setV1Present(true)
@@ -222,10 +225,6 @@ func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxResp
}
func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
- if !igmp.opts.Enabled {
- return
- }
-
igmp.mu.Lock()
defer igmp.mu.Unlock()
igmp.mu.genericMulticastProtocol.HandleReport(groupAddress)
@@ -279,49 +278,46 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
//
// If the group already exists in the membership map, returns
// tcpip.ErrDuplicateAddress.
-func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) *tcpip.Error {
- if !igmp.opts.Enabled {
- return nil
- }
-
- // As per RFC 2236 section 6 page 10,
- //
- // The all-systems group (address 224.0.0.1) is handled as a special
- // case. The host starts in Idle Member state for that group on every
- // interface, never transitions to another state, and never sends a
- // report for that group.
- //
- // This is equivalent to not performing IGMP for the all-systems multicast
- // address. Simply not performing IGMP when the group is added will prevent
- // any work from being done on the all-systems multicast group when leaving
- // the group or when query or report messages are received for it since the
- // MGP state will not know about it.
- if groupAddress == header.IPv4AllSystems {
- return nil
- }
-
+func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
igmp.mu.Lock()
defer igmp.mu.Unlock()
+ igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress, !igmp.ep.Enabled() /* dontInitialize */)
+}
- // JoinGroup returns false if we have already joined the group.
- if !igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress) {
- return tcpip.ErrDuplicateAddress
- }
- return nil
+// isInGroup returns true if the specified group has been joined locally.
+func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+ return igmp.mu.genericMulticastProtocol.IsLocallyJoined(groupAddress)
}
// leaveGroup handles removing the group from the membership map, cancels any
// delay timers associated with that group, and sends the Leave Group message
// if required.
-//
-// If the group does not exist in the membership map, this function will
-// silently return.
-func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) {
- if !igmp.opts.Enabled {
- return
+func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+
+ // LeaveGroup returns false only if the group was not joined.
+ if igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) {
+ return nil
}
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of IGMP, but remains
+// joined locally.
+func (igmp *igmpState) softLeaveAll() {
+ igmp.mu.Lock()
+ defer igmp.mu.Unlock()
+ igmp.mu.genericMulticastProtocol.MakeAllNonMember()
+}
+
+// initializeAll attemps to initialize the IGMP state for each group that has
+// been joined locally.
+func (igmp *igmpState) initializeAll() {
igmp.mu.Lock()
defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress)
+ igmp.mu.genericMulticastProtocol.InitializeGroups()
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index ce2087002..354ac1e1d 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -127,16 +127,18 @@ func (e *endpoint) Enable() *tcpip.Error {
// endpoint may have left groups from the perspective of IGMP when the
// endpoint was disabled. Either way, we need to let routers know to
// send us multicast traffic.
- joinedGroups := e.mu.addressableEndpointState.JoinedGroups()
- for _, group := range joinedGroups {
- e.igmp.joinGroup(group)
- }
+ e.igmp.initializeAll()
// As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
// multicast group. Note, the IANA calls the all-hosts multicast group the
// all-systems multicast group.
- _, err = e.joinGroupLocked(header.IPv4AllSystems)
- return err
+ if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv4 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err))
+ }
+
+ return nil
}
// Enabled implements stack.NetworkEndpoint.
@@ -173,16 +175,13 @@ func (e *endpoint) disableLocked() {
}
// The endpoint may have already left the multicast group.
- if _, err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
// Leave groups from the perspective of IGMP so that routers know that
// we are no longer interested in the group.
- joinedGroups := e.mu.addressableEndpointState.JoinedGroups()
- for _, group := range joinedGroups {
- e.igmp.leaveGroup(group)
- }
+ e.igmp.softLeaveAll()
// The address may have already been removed.
if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
@@ -849,69 +848,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.joinGroupLocked(addr)
}
-// joinGroupLocked is like JoinGroup, but with locking requirements.
+// joinGroupLocked is like JoinGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV4MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
- }
- // TODO(gvisor.dev/issue/4916): Keep track of join count and IGMP state in a
- // single type.
- joined, err := e.mu.addressableEndpointState.JoinGroup(addr)
- if err != nil || !joined {
- return joined, err
- }
-
- // Only join the group from the perspective of IGMP when the endpoint is
- // enabled.
- //
- // If we are not enabled right now, we will join the group from the
- // perspective of IGMP when the endpoint is enabled.
- if !e.Enabled() {
- return true, nil
- }
-
- // joinGroup only returns an error if we try to join a group twice, but we
- // checked above to make sure that the group was newly joined.
- if err := e.igmp.joinGroup(addr); err != nil {
- panic(fmt.Sprintf("e.igmp.joinGroup(%s): %s", addr, err))
+ return tcpip.ErrBadAddress
}
- return true, nil
+ e.igmp.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.leaveGroupLocked(addr)
}
-// leaveGroupLocked is like LeaveGroup, but with locking requirements.
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
- left, err := e.mu.addressableEndpointState.LeaveGroup(addr)
- if err != nil || !left {
- return left, err
- }
-
- e.igmp.leaveGroup(addr)
- return true, nil
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.igmp.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.igmp.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 4d49afcbb..084c38455 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -232,10 +232,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// endpoint may have left groups from the perspective of MLD when the
// endpoint was disabled. Either way, we need to let routers know to
// send us multicast traffic.
- joinedGroups := e.mu.addressableEndpointState.JoinedGroups()
- for _, group := range joinedGroups {
- e.mld.joinGroup(group)
- }
+ e.mld.initializeAll()
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
@@ -254,8 +251,10 @@ func (e *endpoint) Enable() *tcpip.Error {
// (NDP NS) messages may be sent to the All-Nodes multicast group if the
// source address of the NDP NS is the unspecified address, as per RFC 4861
// section 7.2.4.
- if _, err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil {
- return err
+ if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err))
}
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
@@ -344,16 +343,13 @@ func (e *endpoint) disableLocked() {
e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
- if _, err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
// Leave groups from the perspective of MLD so that routers know that
// we are no longer interested in the group.
- joinedGroups := e.mu.addressableEndpointState.JoinedGroups()
- for _, group := range joinedGroups {
- e.mld.leaveGroup(group)
- }
+ e.mld.softLeaveAll()
}
// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
@@ -1182,8 +1178,10 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.joinGroupLocked(snmc); err != nil {
- return nil, err
+ if err := e.joinGroupLocked(snmc); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err))
}
addressEndpoint.SetKind(stack.PermanentTentative)
@@ -1239,7 +1237,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
}
snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ // The endpoint may have already left the multicast group.
+ if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
@@ -1404,70 +1403,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.joinGroupLocked(addr)
}
-// joinGroupLocked is like JoinGroup, but with locking requirements.
+// joinGroupLocked is like JoinGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV6MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
- }
-
- // TODO(gvisor.dev/issue/4916): Keep track of join count and MLD state in a
- // single type.
- joined, err := e.mu.addressableEndpointState.JoinGroup(addr)
- if err != nil || !joined {
- return joined, err
- }
-
- // Only join the group from the perspective of IGMP when the endpoint is
- // enabled.
- //
- // If we are not enabled right now, we will join the group from the
- // perspective of MLD when the endpoint is enabled.
- if !e.Enabled() {
- return true, nil
- }
-
- // joinGroup only returns an error if we try to join a group twice, but we
- // checked above to make sure that the group was newly joined.
- if err := e.mld.joinGroup(addr); err != nil {
- panic(fmt.Sprintf("e.mld.joinGroup(%s): %s", addr, err))
+ return tcpip.ErrBadAddress
}
- return true, nil
+ e.mld.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.leaveGroupLocked(addr)
}
-// leaveGroupLocked is like LeaveGroup, but with locking requirements.
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
- left, err := e.mu.addressableEndpointState.LeaveGroup(addr)
- if err != nil || !left {
- return left, err
- }
-
- e.mld.leaveGroup(addr)
- return true, nil
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.mld.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.mld.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index b16a1afb0..560c5e01e 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -50,8 +50,7 @@ var _ ip.MulticastGroupProtocol = (*mldState)(nil)
// mldState.init MUST be called to initialize the MLD state.
type mldState struct {
// The IPv6 endpoint this mldState is for.
- ep *endpoint
- opts MLDOptions
+ ep *endpoint
genericMulticastProtocol ip.GenericMulticastProtocolState
}
@@ -70,23 +69,21 @@ func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
// a new mldState.
func (mld *mldState) init(ep *endpoint, opts MLDOptions) {
mld.ep = ep
- mld.opts = opts
- mld.genericMulticastProtocol.Init(ep.protocol.stack.Rand(), ep.protocol.stack.Clock(), mld, UnsolicitedReportIntervalMax)
+ mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{
+ Enabled: opts.Enabled,
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: mld,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv6AllNodesMulticastAddress,
+ })
}
func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) {
- if !mld.opts.Enabled {
- return
- }
-
mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
}
func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
- if !mld.opts.Enabled {
- return
- }
-
mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress())
}
@@ -94,45 +91,37 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
// messages.
//
// If the group is already joined, returns tcpip.ErrDuplicateAddress.
-func (mld *mldState) joinGroup(groupAddress tcpip.Address) *tcpip.Error {
- if !mld.opts.Enabled {
- return nil
- }
-
- // As per RFC 2710 section 5 page 10,
- //
- // The link-scope all-nodes address (FF02::1) is handled as a special
- // case. The node starts in Idle Listener state for that address on
- // every interface, never transitions to another state, and never sends
- // a Report or Done for that address.
- //
- // This is equivalent to not performing MLD for the all-nodes multicast
- // address. Simply not performing MLD when the group is added will prevent
- // any work from being done on the all-nodes multicast group when leaving the
- // group or when query or report messages are received for it since the MGP
- // state will not know about it.
- if groupAddress == header.IPv6AllNodesMulticastAddress {
- return nil
- }
+func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
+ mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */)
+}
- // JoinGroup returns false if we have already joined the group.
- if !mld.genericMulticastProtocol.JoinGroup(groupAddress) {
- return tcpip.ErrDuplicateAddress
- }
- return nil
+// isInGroup returns true if the specified group has been joined locally.
+func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
+ return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress)
}
// leaveGroup handles removing the group from the membership map, cancels any
// delay timers associated with that group, and sends the Done message, if
// required.
-//
-// If the group is not joined, this function will do nothing.
-func (mld *mldState) leaveGroup(groupAddress tcpip.Address) {
- if !mld.opts.Enabled {
- return
+func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ // LeaveGroup returns false only if the group was not joined.
+ if mld.genericMulticastProtocol.LeaveGroup(groupAddress) {
+ return nil
}
- mld.genericMulticastProtocol.LeaveGroup(groupAddress)
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of MLD, but remains
+// joined locally.
+func (mld *mldState) softLeaveAll() {
+ mld.genericMulticastProtocol.MakeAllNonMember()
+}
+
+// initializeAll attemps to initialize the MLD state for each group that has
+// been joined locally.
+func (mld *mldState) initializeAll() {
+ mld.genericMulticastProtocol.InitializeGroups()
}
func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error {
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 6e855d815..b438e3acc 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
// AddressableEndpointState is an implementation of an AddressableEndpoint.
@@ -37,10 +36,6 @@ type AddressableEndpointState struct {
endpoints map[tcpip.Address]*addressState
primary []*addressState
-
- // groups holds the mapping between group addresses and the number of times
- // they have been joined.
- groups map[tcpip.Address]uint32
}
}
@@ -53,7 +48,6 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
a.mu.Lock()
defer a.mu.Unlock()
a.mu.endpoints = make(map[tcpip.Address]*addressState)
- a.mu.groups = make(map[tcpip.Address]uint32)
}
// ReadOnlyAddressableEndpointState provides read-only access to an
@@ -335,11 +329,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
-
- if _, ok := a.mu.groups[addr]; ok {
- panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
- }
-
return a.removePermanentAddressLocked(addr)
}
@@ -588,61 +577,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi
return addrs
}
-// JoinGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- a.mu.groups[group] = joins + 1
- return !ok, nil
-}
-
-// LeaveGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- return false, tcpip.ErrBadLocalAddress
- }
-
- if joins == 1 {
- delete(a.mu.groups, group)
- return true, nil
- }
-
- a.mu.groups[group] = joins - 1
- return false, nil
-}
-
-// IsInGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
- a.mu.RLock()
- defer a.mu.RUnlock()
- _, ok := a.mu.groups[group]
- return ok
-}
-
-// JoinedGroups returns a list of groups the endpoint is a member of.
-func (a *AddressableEndpointState) JoinedGroups() []tcpip.Address {
- a.mu.RLock()
- defer a.mu.RUnlock()
- groups := make([]tcpip.Address, 0, len(a.mu.groups))
- for g := range a.mu.groups {
- groups = append(groups, g)
- }
- return groups
-}
-
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
defer a.mu.Unlock()
- a.mu.groups = make(map[tcpip.Address]uint32)
-
for _, ep := range a.mu.endpoints {
// removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 0c8040c67..140f146f6 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -15,40 +15,12 @@
package stack_test
import (
- "sort"
"testing"
- "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-func TestJoinedGroups(t *testing.T) {
- const addr1 = tcpip.Address("\x01")
- const addr2 = tcpip.Address("\x02")
-
- var ep fakeNetworkEndpoint
- var s stack.AddressableEndpointState
- s.Init(&ep)
-
- if joined, err := s.JoinGroup(addr1); err != nil {
- t.Fatalf("JoinGroup(%s): %s", addr1, err)
- } else if !joined {
- t.Errorf("got JoinGroup(%s) = false, want = true", addr1)
- }
- if joined, err := s.JoinGroup(addr2); err != nil {
- t.Fatalf("JoinGroup(%s): %s", addr2, err)
- } else if !joined {
- t.Errorf("got JoinGroup(%s) = false, want = true", addr2)
- }
-
- joinedGroups := s.JoinedGroups()
- sort.Slice(joinedGroups, func(i, j int) bool { return joinedGroups[i][0] < joinedGroups[j][0] })
- if diff := cmp.Diff([]tcpip.Address{addr1, addr2}, joinedGroups); diff != "" {
- t.Errorf("joined groups mismatch (-want +got):\n%s", diff)
- }
-}
-
// TestAddressableEndpointStateCleanup tests that cleaning up an addressable
// endpoint state removes permanent addresses and leaves groups.
func TestAddressableEndpointStateCleanup(t *testing.T) {
@@ -81,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep.DecRef()
}
- group := tcpip.Address("\x02")
- if added, err := s.JoinGroup(group); err != nil {
- t.Fatalf("s.JoinGroup(%s): %s", group, err)
- } else if !added {
- t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
- }
- if !s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
- }
-
s.Cleanup()
- {
- ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
- if ep != nil {
- ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
- }
- }
- if s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 5f3757de0..1805a8e0a 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -563,8 +563,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
return tcpip.ErrNotSupported
}
- _, err := gep.JoinGroup(addr)
- return err
+ return gep.JoinGroup(addr)
}
// leaveGroup decrements the count for the given multicast address, and when it
@@ -580,11 +579,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres
return tcpip.ErrNotSupported
}
- if _, err := gep.LeaveGroup(addr); err != nil {
- return err
- }
-
- return nil
+ return gep.LeaveGroup(addr)
}
// isInGroup returns true if n has joined the multicast group addr.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 43ca03ada..236e4e4c7 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -291,14 +291,10 @@ type NetworkHeaderParams struct {
// endpoints may associate themselves with the same identifier (group address).
type GroupAddressableEndpoint interface {
// JoinGroup joins the specified group.
- //
- // Returns true if the group was newly joined.
- JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+ JoinGroup(group tcpip.Address) *tcpip.Error
// LeaveGroup attempts to leave the specified group.
- //
- // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
- LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+ LeaveGroup(group tcpip.Address) *tcpip.Error
// IsInGroup returns true if the endpoint is a member of the specified group.
IsInGroup(group tcpip.Address) bool