summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go130
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go156
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go47
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go61
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go21
-rw-r--r--pkg/tcpip/network/ipv6/BUILD3
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go85
-rw-r--r--pkg/tcpip/network/ipv6/mld.go83
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go272
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go15
-rw-r--r--pkg/tcpip/network/multicast_group_test.go200
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go8
-rw-r--r--pkg/tcpip/stack/nic.go6
13 files changed, 949 insertions, 138 deletions
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
index c8c17ab15..f85c5ff9d 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -30,6 +30,23 @@ type hostState int
// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1
// (RFC 2710 section 5). Even though the states are generic across both IGMPv2
// and MLDv1, IGMPv2 terminology will be used.
+//
+// ______________receive query______________
+// | |
+// | _____send or receive report_____ |
+// | | | |
+// V | V |
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+ |
+// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | -
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+
+// | ^ | ^ | ^ | ^
+// | | | | | | | |
+// ---------- ------- ---------- -------------
+// initialize new send inital fail to send send or receive
+// group membership report delayed report report
+//
+// Not shown in the diagram above, but any state may transition into the non
+// member state when a group is left.
const (
// nonMember is the "'Non-Member' state, when the host does not belong to the
// group on the interface. This is the initial state for all memberships on
@@ -41,6 +58,15 @@ const (
// but without advertising the membership to the network.
nonMember hostState = iota
+ // pendingMember is a newly joined member that is waiting to successfully send
+ // the initial set of reports.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the initial report needs to be sent.
+ //
+ // MAY NOT transition to the idle member state from this state.
+ pendingMember
+
// delayingMember is the "'Delaying Member' state, when the host belongs to
// the group on the interface and has a report delay timer running for that
// membership."
@@ -48,6 +74,16 @@ const (
// 'Delaying Listener' is the MLDv1 term used to describe this state.
delayingMember
+ // queuedDelayingMember is a delayingMember that failed to send a report after
+ // its delayed report timer fired. Hosts in this state are waiting to attempt
+ // retransmission of the delayed report.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the delayed report needs to be sent.
+ //
+ // May transition to idle member if a report is received for a group.
+ queuedDelayingMember
+
// idleMember is the "Idle Member" state, when the host belongs to the group
// on the interface and does not have a report delay timer running for that
// membership.
@@ -56,6 +92,17 @@ const (
idleMember
)
+func (s hostState) isDelayingMember() bool {
+ switch s {
+ case nonMember, pendingMember, idleMember:
+ return false
+ case delayingMember, queuedDelayingMember:
+ return true
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", s))
+ }
+}
+
// multicastGroupState holds the Generic Multicast Protocol state for a
// multicast group.
type multicastGroupState struct {
@@ -124,7 +171,10 @@ type GenericMulticastProtocolOptions struct {
// can be represented by GenericMulticastProtocolState.
type MulticastGroupProtocol interface {
// SendReport sends a multicast report for the specified group address.
- SendReport(groupAddress tcpip.Address) *tcpip.Error
+ //
+ // Returns false if the caller should queue the report to be sent later. Note,
+ // returning false does not mean that the receiver hit an error.
+ SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error)
// SendLeave sends a multicast leave for the specified group address.
SendLeave(groupAddress tcpip.Address) *tcpip.Error
@@ -166,6 +216,9 @@ type GenericMulticastProtocolState struct {
//
// The GenericMulticastProtocolState will only grab the lock when timers/jobs
// fire.
+//
+// Note: the methods on opts.Protocol will always be called while protocolMU is
+// held.
func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) {
if g.memberships != nil {
panic("attempted to initialize generic membership protocol state twice")
@@ -212,6 +265,29 @@ func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
}
}
+// SendQueuedReportsLocked attempts to send reports for groups that failed to
+// send reports during their last attempt.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
+ for groupAddress, info := range g.memberships {
+ switch info.state {
+ case nonMember, delayingMember, idleMember:
+ case pendingMember:
+ // pendingMembers failed to send their initial unsolicited report so try
+ // to send the report and queue the extra unsolicited reports.
+ g.maybeSendInitialReportLocked(groupAddress, &info)
+ case queuedDelayingMember:
+ // queuedDelayingMembers failed to send their delayed reports so try to
+ // send the report and transition them to the idle state.
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", info.state))
+ }
+ g.memberships[groupAddress] = info
+ }
+}
+
// JoinGroupLocked handles joining a new group.
//
// If dontInitialize is true, the group will be not be initialized and will be
@@ -239,8 +315,7 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre
panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
}
- info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil
- info.state = idleMember
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
g.memberships[groupAddress] = info
}),
}
@@ -347,7 +422,7 @@ func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Ad
// multicast address while it has a timer running for that same address
// on that interface, it stops its timer and does not send a Report for
// that address, thus suppressing duplicate reports on the link.
- if info, ok := g.memberships[groupAddress]; ok && info.state == delayingMember {
+ if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() {
info.delayedReportJob.Cancel()
info.lastToSendReport = false
info.state = idleMember
@@ -360,10 +435,10 @@ func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Ad
// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
if info.state != nonMember {
- panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state))
+ panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state))
}
- info.state = idleMember
+ info.lastToSendReport = false
if groupAddress == g.opts.AllNodesAddress {
// As per RFC 2236 section 6 page 10 (for IGMPv2),
@@ -379,9 +454,25 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t
// case. The node starts in Idle Listener state for that address on
// every interface, never transitions to another state, and never sends
// a Report or Done for that address.
+ info.state = idleMember
return
}
+ info.state = pendingMember
+ g.maybeSendInitialReportLocked(groupAddress, info)
+}
+
+// maybeSendInitialReportLocked attempts to start transmission of the initial
+// set of reports after newly joining a group.
+//
+// Host must be in pending member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != pendingMember {
+ panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
// As per RFC 2236 section 3 page 5 (for IGMPv2),
//
// When a host joins a multicast group, it should immediately transmit an
@@ -399,8 +490,30 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t
//
// TODO(gvisor.dev/issue/4901): Support a configurable number of initial
// unsolicited reports.
- info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil
- g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+ }
+}
+
+// maybeSendDelayedReportLocked attempts to send the delayed report.
+//
+// Host must be in pending, delaying or queued delaying member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if !info.state.isDelayingMember() {
+ panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ info.state = idleMember
+ } else {
+ info.state = queuedDelayingMember
+ }
}
// maybeSendLeave attempts to send a leave message.
@@ -531,6 +644,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr
// TODO: Reset the timer if time remaining is greater than maxResponseTime.
return
}
+
info.state = delayingMember
info.delayedReportJob.Cancel()
info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime))
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
index 6fd0eb9f7..95040515c 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -47,6 +47,9 @@ type mockMulticastGroupProtocol struct {
// Must only be accessed with mu held.
sendLeaveGroupAddrCount map[tcpip.Address]int
+
+ // Must only be accessed with mu held.
+ makeQueuePackets bool
}
func (m *mockMulticastGroupProtocol) init() {
@@ -60,7 +63,7 @@ func (m *mockMulticastGroupProtocol) initLocked() {
m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
}
-func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
if m.mu.TryLock() {
m.mu.Unlock()
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
@@ -71,7 +74,7 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcp
}
m.sendReportGroupAddrCount[groupAddress]++
- return nil
+ return !m.makeQueuePackets, nil
}
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
@@ -112,7 +115,7 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr
// ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
cmp.FilterPath(
func(p cmp.Path) bool {
- return p.Last().String() == ".mu" || p.Last().String() == ".t"
+ return p.Last().String() == ".mu" || p.Last().String() == ".t" || p.Last().String() == ".makeQueuePackets"
},
cmp.Ignore(),
),
@@ -732,3 +735,150 @@ func TestGroupStateNonMember(t *testing.T) {
})
}
}
+
+func TestQueuedPackets(t *testing.T) {
+ var g ip.GenericMulticastProtocolState
+ var mgp mockMulticastGroupProtocol
+ mgp.init()
+ clock := faketime.NewManualClock()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
+ Enabled: true,
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ Protocol: &mgp,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+
+ // Joining should trigger a SendReport, but mgp should report that we did not
+ // send the packet.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report timer should have been cancelled since we did not send
+ // the initial report earlier.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to successfully send the report.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = false
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send (we should be idle).
+ mgp.mu.Lock()
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query but mock being unable to send reports again.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.HandleQueryLocked(addr1, time.Nanosecond)
+ mgp.mu.Unlock()
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to send reports again - we should have a packet queued to
+ // send.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = false
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.mu.Lock()
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query again, but mock being unable to send reports.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.HandleQueryLocked(addr1, time.Nanosecond)
+ mgp.mu.Unlock()
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report should should transition us into the idle member state,
+ // even if we had a packet queued. We should no longer have any packets to
+ // send.
+ mgp.mu.Lock()
+ g.HandleReportLocked(addr1)
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // When we fail to send the initial set of reports, incoming reports should
+ // not affect a newly joined group's reports from being sent.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = true
+ g.JoinGroupLocked(addr2, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.mu.Lock()
+ g.HandleReportLocked(addr2)
+ // Attempting to send queued reports while still unable to send reports should
+ // not change the host state.
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Mock being able to successfully send the report.
+ mgp.mu.Lock()
+ mgp.makeQueuePackets = false
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.mu.Lock()
+ g.SendQueuedReportsLocked()
+ mgp.mu.Unlock()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index a3a7176a0..fb7a9e68e 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -96,7 +96,9 @@ type igmpState struct {
}
// SendReport implements ip.MulticastGroupProtocol.
-func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
igmpType := header.IGMPv2MembershipReport
if igmp.v1Present() {
igmpType = header.IGMPv1MembershipReport
@@ -105,6 +107,8 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
}
// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: igmp.ep.mu must be read locked.
func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
// As per RFC 2236 Section 6, Page 8: "If the interface state says the
// Querier is running IGMPv1, this action SHOULD be skipped. If the flag
@@ -113,7 +117,8 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
if igmp.v1Present() {
return nil
}
- return igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
+ _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
+ return err
}
// init sets up an igmpState struct, and is required to be called before using
@@ -235,9 +240,10 @@ func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
igmp.genericMulticastProtocol.HandleReportLocked(groupAddress)
}
-// writePacket assembles and sends an IGMP packet with the provided fields,
-// incrementing the provided stat counter on success.
-func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) *tcpip.Error {
+// writePacket assembles and sends an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) {
igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize))
igmpData.SetType(igmpType)
igmpData.SetGroupAddress(groupAddress)
@@ -248,9 +254,13 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
Data: buffer.View(igmpData).ToVectorisedView(),
})
- // TODO(gvisor.dev/issue/4888): We should not use the unspecified address,
- // rather we should select an appropriate local address.
- localAddr := header.IPv4Any
+ addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */)
+ if addressEndpoint == nil {
+ return false, nil
+ }
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
+ addressEndpoint = nil
igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.IGMPProtocolNumber,
TTL: header.IGMPTTL,
@@ -259,22 +269,22 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
&header.IPv4SerializableRouterAlertOption{},
})
- sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
+ sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sent.Dropped.Increment()
- return err
+ sentStats.Dropped.Increment()
+ return false, err
}
switch igmpType {
case header.IGMPv1MembershipReport:
- sent.V1MembershipReport.Increment()
+ sentStats.V1MembershipReport.Increment()
case header.IGMPv2MembershipReport:
- sent.V2MembershipReport.Increment()
+ sentStats.V2MembershipReport.Increment()
case header.IGMPLeaveGroup:
- sent.LeaveGroup.Increment()
+ sentStats.LeaveGroup.Increment()
default:
panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
}
- return nil
+ return true, nil
}
// joinGroup handles adding a new group to the membership map, setting up the
@@ -325,3 +335,10 @@ func (igmp *igmpState) softLeaveAll() {
func (igmp *igmpState) initializeAll() {
igmp.genericMulticastProtocol.InitializeGroupsLocked()
}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) sendQueuedReports() {
+ igmp.genericMulticastProtocol.SendQueuedReportsLocked()
+}
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 5e139377b..1ee573ac8 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -16,6 +16,7 @@ package ipv4_test
import (
"testing"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -29,6 +30,7 @@ import (
const (
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ addr = tcpip.Address("\x0a\x00\x00\x01")
multicastAddr = tcpip.Address("\xe0\x00\x00\x03")
nicID = 1
)
@@ -41,6 +43,7 @@ func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.
payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
checker.IPv4(t, payload,
+ checker.SrcAddr(addr),
checker.DstAddr(remoteAddress),
// TTL for an IGMP message must be 1 as per RFC 2236 section 2.
checker.TTL(1),
@@ -71,7 +74,6 @@ func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stac
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
-
return e, s, clock
}
@@ -104,6 +106,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
// reports for backwards compatibility.
func TestIgmpV1Present(t *testing.T) {
e, s, clock := createStack(t, true)
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
@@ -154,3 +159,57 @@ func TestIgmpV1Present(t *testing.T) {
}
validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr)
}
+
+func TestSendQueuedIGMPReports(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ // Joining a group without an assigned address should queue IGMP packets; none
+ // should be sent without an assigned address.
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err)
+ }
+ reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport
+ if got := reportStat.Value(); got != 0 {
+ t.Errorf("got reportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+
+ // The initial set of IGMP reports that were queued should be sent once an
+ // address is assigned.
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
+ if got := reportStat.Value(); got != 1 {
+ t.Errorf("got reportStat.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ clock.Advance(ipv4.UnsolicitedReportIntervalMax)
+ if got := reportStat.Value(); got != 2 {
+ t.Errorf("got reportStat.Value() = %d, want = 2", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should have no more packets to send after the initial set of unsolicited
+ // reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index c63ecca4a..e9ff70d04 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -172,7 +172,7 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.isEnabled() {
return
}
@@ -189,6 +189,10 @@ func (e *endpoint) disableLocked() {
if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// DefaultTTL is the default time-to-live value for this endpoint.
@@ -780,7 +784,12 @@ func (e *endpoint) Close() {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err == nil {
+ e.mu.igmp.sendQueuedReports()
+ }
+ return ep, err
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
@@ -815,6 +824,14 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
e.mu.RLock()
defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 5e75c8740..afa45aefe 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -58,7 +58,10 @@ go_test(
srcs = ["mld_test.go"],
deps = [
":ipv6",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 7288e309c..e506e99e9 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -121,6 +121,45 @@ type OpaqueInterfaceIdentifierOptions struct {
SecretKey []byte
}
+// onAddressAssignedLocked handles an address being assigned.
+//
+// Precondition: e.mu must be exclusively locked.
+func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, ...
+ //
+ // If we just completed DAD for a link-local address, then attempt to send any
+ // queued MLD reports. Note, we may have sent reports already for some of the
+ // groups before we had a valid link-local address to use as the source for
+ // the MLD messages, but that was only so that MLD snooping switches are aware
+ // of our membership to groups - routers would not have handled those reports.
+ //
+ // As per RFC 3590 section 4,
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ if header.IsV6LinkLocalAddress(addr) {
+ e.mu.mld.sendQueuedReports()
+ }
+}
+
// InvalidateDefaultRouter implements stack.NDPEndpoint.
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
@@ -333,7 +372,7 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.Enabled() {
return
}
@@ -349,6 +388,10 @@ func (e *endpoint) disableLocked() {
// Leave groups from the perspective of MLD so that routers know that
// we are no longer interested in the group.
e.mu.mld.softLeaveAll()
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
@@ -1176,13 +1219,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
return addressEndpoint, nil
}
- snmc := header.SolicitedNodeAddr(addr.Address)
- if err := e.joinGroupLocked(snmc); err != nil {
- // joinGroupLocked only returns an error if the group address is not a valid
- // IPv6 multicast address.
- panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err))
- }
-
addressEndpoint.SetKind(stack.PermanentTentative)
if e.Enabled() {
@@ -1191,6 +1227,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
}
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if err := e.joinGroupLocked(snmc); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err))
+ }
+
return addressEndpoint, nil
}
@@ -1292,6 +1335,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow
return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
}
+// getLinkLocalAddressRLocked returns a link-local address from the primary list
+// of addresses, if one is available.
+//
+// See stack.PrimaryEndpointBehavior for more details about the primary list.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address {
+ var linkLocalAddr tcpip.Address
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.IsAssigned(false /* allowExpired */) {
+ if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) {
+ linkLocalAddr = addr
+ return false
+ }
+ }
+ return true
+ })
+ return linkLocalAddr
+}
+
// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
// but with locking requirements.
//
@@ -1311,10 +1374,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// Create a candidate set of available addresses we can potentially use as a
// source address.
var cs []addrCandidate
- e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
// If r is not valid for outgoing connections, it is not a valid endpoint.
if !addressEndpoint.IsAssigned(allowExpired) {
- return
+ return true
}
addr := addressEndpoint.AddressWithPrefix().Address
@@ -1330,6 +1393,8 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
addressEndpoint: addressEndpoint,
scope: scope,
})
+
+ return true
})
remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index b67eafdba..48644d9c8 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -59,13 +59,18 @@ type mldState struct {
}
// SendReport implements ip.MulticastGroupProtocol.
-func (mld *mldState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
}
// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
- return mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ return err
}
// init sets up an mldState struct, and is required to be called before using
@@ -147,7 +152,17 @@ func (mld *mldState) initializeAll() {
mld.genericMulticastProtocol.InitializeGroupsLocked()
}
-func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error {
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) sendQueuedReports() {
+ mld.genericMulticastProtocol.SendQueuedReportsLocked()
+}
+
+// writePacket assembles and sends an MLD packet.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) {
sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
var mldStat *tcpip.StatCounter
switch mldType {
@@ -162,9 +177,61 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp
icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize))
icmp.SetType(mldType)
header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress)
- // TODO(gvisor.dev/issue/4888): We should not use the unspecified address,
- // rather we should select an appropriate local address.
- localAddress := header.IPv6Any
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert
+ // option in a Hop-by-Hop Options header.
+ //
+ // However, this would cause problems with Duplicate Address Detection with
+ // the first address as MLD snooping switches may not send multicast traffic
+ // that DAD depends on to the node performing DAD without the MLD report, as
+ // documented in RFC 4816:
+ //
+ // Note that when a node joins a multicast address, it typically sends a
+ // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810]
+ // for the multicast address. In the case of Duplicate Address
+ // Detection, the MLD report message is required in order to inform MLD-
+ // snooping switches, rather than routers, to forward multicast packets.
+ // In the above description, the delay for joining the multicast address
+ // thus means delaying transmission of the corresponding MLD report
+ // message. Since the MLD specifications do not request a random delay
+ // to avoid race conditions, just delaying Neighbor Solicitation would
+ // cause congestion by the MLD report messages. The congestion would
+ // then prevent the MLD-snooping switches from working correctly and, as
+ // a result, prevent Duplicate Address Detection from working. The
+ // requirement to include the delay for the MLD report in this case
+ // avoids this scenario. [RFC3590] also talks about some interaction
+ // issues between Duplicate Address Detection and MLD, and specifies
+ // which source address should be used for the MLD report in this case.
+ //
+ // As per RFC 3590 section 4, we should still send out MLD reports with an
+ // unspecified source address if we do not have an assigned link-local
+ // address to use as the source address to ensure DAD works as expected on
+ // networks with MLD snooping switches:
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ localAddress := mld.ep.getLinkLocalAddressRLocked()
+ if len(localAddress) == 0 {
+ localAddress = header.IPv6Any
+ }
+
icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -180,8 +247,8 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp
// Membership Reports.
if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sentStats.Dropped.Increment()
- return err
+ return false, err
}
mldStat.Increment()
- return nil
+ return localAddress != header.IPv6Any, nil
}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index 5677bdd54..93b8b3c5c 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -16,8 +16,12 @@ package ipv6_test
import (
"testing"
+ "time"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -25,9 +29,31 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
)
+var (
+ linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr)
+ globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
+)
+
+func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) {
+ t.Helper()
+
+ checker.IPv6(t, p,
+ checker.SrcAddr(localAddress),
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(mldType, header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(0),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
const nicID = 1
@@ -46,45 +72,223 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
// The stack will join an address's solicited node multicast address when
// an address is added. An MLD report message should be sent for the
// solicited-node group.
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
}
- {
- p, ok := e.Read()
- if !ok {
- t.Fatal("expected a report message to be sent")
- }
- snmc := header.SolicitedNodeAddr(addr1)
- checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())),
- checker.DstAddr(snmc),
- // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
- checker.TTL(1),
- checker.MLD(header.ICMPv6MulticastListenerReport, header.MLDMinimumSize,
- checker.MLDMaxRespDelay(0),
- checker.MLDMulticastAddress(snmc),
- ),
- )
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
}
// The stack will leave an address's solicited node multicast address when
// an address is removed. An MLD done message should be sent for the
// solicited-node group.
- if err := s.RemoveAddress(nicID, addr1); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
+ if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a done message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
}
- {
- p, ok := e.Read()
- if !ok {
- t.Fatal("expected a done message to be sent")
- }
- snmc := header.SolicitedNodeAddr(addr1)
- checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(1),
- checker.MLD(header.ICMPv6MulticastListenerDone, header.MLDMinimumSize,
- checker.MLDMaxRespDelay(0),
- checker.MLDMulticastAddress(snmc),
- ),
- )
+}
+
+func TestSendQueuedMLDReports(t *testing.T) {
+ const (
+ nicID = 1
+ maxReports = 2
+ )
+
+ tests := []struct {
+ name string
+ dadTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD Disabled",
+ dadTransmits: 0,
+ retransmitTimer: 0,
+ },
+ {
+ name: "DAD Enabled",
+ dadTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: test.dadTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ },
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ Clock: clock,
+ })
+
+ // Allow space for an extra packet so we can observe packets that were
+ // unexpectedly sent.
+ e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ resolveDAD := func(addr, snmc tcpip.Address) {
+ clock.Advance(dadResolutionTime)
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected DAD packet")
+ } else {
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(addr),
+ checker.NDPNSOptions(nil),
+ ))
+ }
+ }
+
+ var reportCounter uint64
+ reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ var doneCounter uint64
+ doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+
+ // Joining a group without an assigned address should send an MLD report
+ // with the unspecified address.
+ if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalMulticastAddr)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a global address should not send reports for the already joined
+ // group since we should only send queued reports when a link-local
+ // addres sis assigned.
+ //
+ // Note, we will still expect to send a report for the global address's
+ // solicited node address from the unspecified address as per RFC 3590
+ // section 4.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC)
+ }
+ if dadResolutionTime != 0 {
+ // Reports should not be sent when the address resolves.
+ resolveDAD(globalAddr, globalAddrSNMC)
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ }
+ // Leave the group since we don't care about the global address's
+ // solicited node multicast group membership.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err)
+ }
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a link-local address should send a report for its solicited node
+ // address and globalMulticastAddr.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ }
+ if dadResolutionTime != 0 {
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", linkLocalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+ resolveDAD(linkLocalAddr, linkLocalAddrSNMC)
+ }
+
+ // We expect two batches of reports to be sent (1 batch when the
+ // link-local address is assigned, and another after the maximum
+ // unsolicited report interval.
+ for i := 0; i < 2; i++ {
+ // We expect reports to be sent (one for globalMulticastAddr and another
+ // for linkLocalAddrSNMC).
+ reportCounter += maxReports
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+
+ addrs := map[tcpip.Address]bool{
+ globalMulticastAddr: false,
+ linkLocalAddrSNMC: false,
+ }
+ for _ = range addrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs)
+ }
+
+ addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress()
+ if seen, ok := addrs[addr]; !ok {
+ t.Fatalf("got unexpected packet destined to %s", addr)
+ } else if seen {
+ t.Fatalf("got another packet destined to %s", addr)
+ }
+
+ addrs[addr] = true
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr)
+
+ clock.Advance(ipv6.UnsolicitedReportIntervalMax)
+ }
+ }
+
+ // Should not send any more reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 2f5e2e82c..3b892aeda 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -647,6 +647,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
+ ndp.ep.onAddressAssignedLocked(addr)
return nil
}
@@ -690,12 +691,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
}
- // If DAD resolved for a stable SLAAC address, attempt generation of a
- // temporary SLAAC address.
- if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
- // Reset the generation attempts counter as we are starting the generation
- // of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ if dadDone {
+ if addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
+ // Reset the generation attempts counter as we are starting the
+ // generation of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
+
+ ndp.ep.onAddressAssignedLocked(addr)
}
}),
}
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index 152986026..6579cd3c9 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -35,6 +35,9 @@ import (
const (
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ ipv4Addr = tcpip.Address("\x0a\x00\x00\x01")
+ ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+
ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03")
ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04")
ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05")
@@ -49,6 +52,8 @@ const (
mldQuery = uint8(header.ICMPv6MulticastListenerQuery)
mldReport = uint8(header.ICMPv6MulticastListenerReport)
mldDone = uint8(header.ICMPv6MulticastListenerDone)
+
+ maxUnsolicitedReports = 2
)
var (
@@ -62,6 +67,8 @@ var (
}
return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
}()
+
+ ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr)
)
// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
@@ -71,6 +78,7 @@ func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.A
payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
checker.IPv6(t, payload,
+ checker.SrcAddr(ipv6Addr),
checker.DstAddr(remoteAddress),
// Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
checker.TTL(1),
@@ -88,6 +96,7 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.
payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
checker.IPv4(t, payload,
+ checker.SrcAddr(ipv4Addr),
checker.DstAddr(remoteAddress),
// TTL for an IGMP message must be 1 as per RFC 2236 section 2.
checker.TTL(1),
@@ -100,30 +109,31 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.
)
}
-func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
t.Helper()
- // Create an endpoint of queue size 2, since no more than 2 packets are ever
- // queued in the tests in this file.
- e := channel.New(2, header.IPv6MinimumMTU, linkAddr)
- s, clock := createStackWithLinkEndpoint(t, mgpEnabled, e)
+ e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr)
+ s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e)
return e, s, clock
}
-func createStackWithLinkEndpoint(t *testing.T, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
+func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
t.Helper()
+ igmpEnabled := v4 && mgpEnabled
+ mldEnabled := !v4 && mgpEnabled
+
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocolWithOptions(ipv4.Options{
IGMP: ipv4.IGMPOptions{
- Enabled: mgpEnabled,
+ Enabled: igmpEnabled,
},
}),
ipv6.NewProtocolWithOptions(ipv6.Options{
MLD: ipv6.MLDOptions{
- Enabled: mgpEnabled,
+ Enabled: mldEnabled,
},
}),
},
@@ -132,10 +142,61 @@ func createStackWithLinkEndpoint(t *testing.T, mgpEnabled bool, e stack.LinkEndp
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err)
+ }
return s, clock
}
+// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join
+// when it is created with an IPv6 address.
+//
+// To not interfere with tests, checkInitialIPv6Groups will leave the added
+// address's solicited node multicast group so that the tests can all assume
+// the NIC has not joined any IPv6 groups.
+func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) {
+ t.Helper()
+
+ stats := s.Stats().ICMP.V6.PacketsSent
+
+ reportCounter++
+ if got := stats.MulticastListenerReport.Value(); got != reportCounter {
+ t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC)
+ }
+
+ // Leave the group to not affect the tests. This is fine since we are not
+ // testing DAD or the solicited node address specifically.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err)
+ }
+ leaveCounter++
+ if got := stats.MulticastListenerDone.Value(); got != leaveCounter {
+ t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+
+ return reportCounter, leaveCounter
+}
+
// createAndInjectIGMPPacket creates and injects an IGMP packet with the
// specified fields.
//
@@ -240,13 +301,13 @@ func TestMGPDisabled(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, false)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */)
// This NIC may join multicast groups when it is enabled but since MGP is
// disabled, no reports should be sent.
sentReportStat := test.sentReportStat(s)
if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportState.Value() = %d, want = 0", got)
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
}
clock.Advance(time.Hour)
if p, ok := e.Read(); ok {
@@ -259,7 +320,7 @@ func TestMGPDisabled(t *testing.T) {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportState.Value() = %d, want = 0", got)
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
}
clock.Advance(time.Hour)
if p, ok := e.Read(); ok {
@@ -363,7 +424,7 @@ func TestMGPReceiveCounters(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, _ := createStack(t, true)
+ e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */)
test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress)
if got := test.statCounter(s).Value(); got != 1 {
@@ -384,6 +445,7 @@ func TestMGPJoinGroup(t *testing.T) {
sentReportStat func(*stack.Stack) *tcpip.StatCounter
receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
validateReport func(*testing.T, channel.PacketInfo)
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
}{
{
name: "IGMP",
@@ -418,21 +480,28 @@ func TestMGPJoinGroup(t *testing.T) {
validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
},
+ checkInitialGroups: checkInitialIPv6Groups,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, true)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+ }
// Test joining a specific address explicitly and verify a Report is sent
// immediately.
if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
+ reportCounter++
sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != 1 {
- t.Errorf("got sentReportState.Value() = %d, want = 1", got)
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -450,8 +519,9 @@ func TestMGPJoinGroup(t *testing.T) {
t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
}
clock.Advance(test.maxUnsolicitedResponseDelay)
- if got := sentReportStat.Value(); got != 2 {
- t.Errorf("got sentReportState.Value() = %d, want = 2", got)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -472,13 +542,14 @@ func TestMGPJoinGroup(t *testing.T) {
// group the stack sends a leave/done message.
func TestMGPLeaveGroup(t *testing.T) {
tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
- validateReport func(*testing.T, channel.PacketInfo)
- validateLeave func(*testing.T, channel.PacketInfo)
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ validateLeave func(*testing.T, channel.PacketInfo)
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
}{
{
name: "IGMP",
@@ -521,18 +592,26 @@ func TestMGPLeaveGroup(t *testing.T) {
validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
},
+ checkInitialGroups: checkInitialIPv6Groups,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, true)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
- if got := test.sentReportStat(s).Value(); got != 1 {
- t.Errorf("got sentReportStat(_).Value() = %d, want = 1", got)
+ reportCounter++
+ if got := test.sentReportStat(s).Value(); got != reportCounter {
+ t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -547,8 +626,9 @@ func TestMGPLeaveGroup(t *testing.T) {
if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
}
- if got := test.sentLeaveStat(s).Value(); got != 1 {
- t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 1", got)
+ leaveCounter++
+ if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a leave message to be sent")
@@ -578,6 +658,7 @@ func TestMGPQueryMessages(t *testing.T) {
rxQuery func(*channel.Endpoint, uint8, tcpip.Address)
validateReport func(*testing.T, channel.PacketInfo)
maxRespTimeToDuration func(uint8) time.Duration
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
}{
{
name: "IGMP",
@@ -622,6 +703,7 @@ func TestMGPQueryMessages(t *testing.T) {
maxRespTimeToDuration: func(d uint8) time.Duration {
return time.Duration(d) * time.Millisecond
},
+ checkInitialGroups: checkInitialIPv6Groups,
},
}
@@ -655,16 +737,22 @@ func TestMGPQueryMessages(t *testing.T) {
for _, subTest := range subTests {
t.Run(subTest.name, func(t *testing.T) {
- e, s, clock := createStack(t, true)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+ }
if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
sentReportStat := test.sentReportStat(s)
- for i := uint64(1); i <= 2; i++ {
+ for i := 0; i < maxUnsolicitedReports; i++ {
sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != i {
- t.Errorf("(i=%d) got sentReportState.Value() = %d, want = %d", i, got, i)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatalf("expected %d-th report message to be sent", i)
@@ -694,8 +782,9 @@ func TestMGPQueryMessages(t *testing.T) {
if subTest.expectReport {
clock.Advance(test.maxRespTimeToDuration(maxRespTime))
- if got := sentReportStat.Value(); got != 3 {
- t.Errorf("got sentReportState.Value() = %d, want = 3", got)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -727,6 +816,7 @@ func TestMGPReportMessages(t *testing.T) {
rxReport func(*channel.Endpoint)
validateReport func(*testing.T, channel.PacketInfo)
maxRespTimeToDuration func(uint8) time.Duration
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
}{
{
name: "IGMP",
@@ -769,19 +859,27 @@ func TestMGPReportMessages(t *testing.T) {
maxRespTimeToDuration: func(d uint8) time.Duration {
return time.Duration(d) * time.Millisecond
},
+ checkInitialGroups: checkInitialIPv6Groups,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, true)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != 1 {
- t.Errorf("got sentReportStat.Value() = %d, want = 1", got)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -796,8 +894,8 @@ func TestMGPReportMessages(t *testing.T) {
// reports.
test.rxReport(e)
clock.Advance(time.Hour)
- if got := sentReportStat.Value(); got != 1 {
- t.Errorf("got sentReportStat.Value() = %d, want = 1", got)
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); ok {
t.Errorf("sent unexpected packet = %#v", p)
@@ -812,8 +910,8 @@ func TestMGPReportMessages(t *testing.T) {
t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
}
clock.Advance(time.Hour)
- if got := test.sentLeaveStat(s).Value(); got != 0 {
- t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 0", got)
+ if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
}
// Should not send any more packets.
@@ -837,6 +935,7 @@ func TestMGPWithNICLifecycle(t *testing.T) {
validateReport func(*testing.T, channel.PacketInfo, tcpip.Address)
validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address)
getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
}{
{
name: "IGMP",
@@ -922,17 +1021,22 @@ func TestMGPWithNICLifecycle(t *testing.T) {
}
seen[addr] = true
return addr
-
},
+ checkInitialGroups: checkInitialIPv6Groups,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, true)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
- sentReportStat := test.sentReportStat(s)
var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ sentReportStat := test.sentReportStat(s)
for _, a := range test.multicastAddrs {
if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
@@ -957,7 +1061,7 @@ func TestMGPWithNICLifecycle(t *testing.T) {
t.Fatalf("DisableNIC(%d): %s", nicID, err)
}
sentLeaveStat := test.sentLeaveStat(s)
- leaveCounter := uint64(len(test.multicastAddrs))
+ leaveCounter += uint64(len(test.multicastAddrs))
if got := sentLeaveStat.Value(); got != leaveCounter {
t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
}
@@ -1059,7 +1163,7 @@ func TestMGPWithNICLifecycle(t *testing.T) {
clock.Advance(test.maxUnsolicitedResponseDelay)
reportCounter++
if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportState.Value() = %d, want = %d", got, reportCounter)
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -1105,7 +1209,7 @@ func TestMGPDisabledOnLoopback(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s, clock := createStackWithLinkEndpoint(t, true /* mgpEnabled */, loopback.New())
+ s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New())
sentReportStat := test.sentReportStat(s)
if got := sentReportStat.Value(); got != 0 {
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 6e4f5fa46..cd423bf71 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -82,12 +82,16 @@ func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool)
}
// ForEachPrimaryEndpoint calls f for each primary address.
-func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
+//
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) {
a.mu.RLock()
defer a.mu.RUnlock()
for _, ep := range a.mu.primary {
- f(ep)
+ if !f(ep) {
+ return
+ }
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 5887aa1ed..a6237dd5f 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -172,7 +172,7 @@ func (n *NIC) disable() {
//
// n MUST be locked.
func (n *NIC) disableLocked() {
- if !n.setEnabled(false) {
+ if !n.Enabled() {
return
}
@@ -184,6 +184,10 @@ func (n *NIC) disableLocked() {
for _, ep := range n.networkEndpoints {
ep.Disable()
}
+
+ if !n.setEnabled(false) {
+ panic("should have only done work to disable the NIC if it was enabled")
+ }
}
// enable enables n.