From 53a95ad0dfe6123df5dd2bef5acfb81ebd796ff6 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 10 Dec 2020 14:47:53 -0800 Subject: Use specified source address for IGMP/MLD packets This change also considers interfaces and network endpoints enabled up up to the point all work to disable them are complete. This was needed so that protocols can perform shutdown work while being disabled (e.g. sending a packet which requires the endpoint to be enabled to obtain a source address). Bug #4682, #4861 Fixes #4888 Startblock: has LGTM from peterjohnston and then add reviewer brunodalbo PiperOrigin-RevId: 346869702 --- pkg/tcpip/network/ip/generic_multicast_protocol.go | 130 +++++++++- .../network/ip/generic_multicast_protocol_test.go | 156 +++++++++++- pkg/tcpip/network/ipv4/igmp.go | 47 ++-- pkg/tcpip/network/ipv4/igmp_test.go | 61 ++++- pkg/tcpip/network/ipv4/ipv4.go | 21 +- pkg/tcpip/network/ipv6/BUILD | 3 + pkg/tcpip/network/ipv6/ipv6.go | 85 ++++++- pkg/tcpip/network/ipv6/mld.go | 83 ++++++- pkg/tcpip/network/ipv6/mld_test.go | 272 ++++++++++++++++++--- pkg/tcpip/network/ipv6/ndp.go | 15 +- pkg/tcpip/network/multicast_group_test.go | 200 +++++++++++---- pkg/tcpip/stack/addressable_endpoint_state.go | 8 +- pkg/tcpip/stack/nic.go | 6 +- 13 files changed, 949 insertions(+), 138 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index c8c17ab15..f85c5ff9d 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -30,6 +30,23 @@ type hostState int // The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1 // (RFC 2710 section 5). Even though the states are generic across both IGMPv2 // and MLDv1, IGMPv2 terminology will be used. +// +// ______________receive query______________ +// | | +// | _____send or receive report_____ | +// | | | | +// V | V | +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ | +// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | - +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ +// | ^ | ^ | ^ | ^ +// | | | | | | | | +// ---------- ------- ---------- ------------- +// initialize new send inital fail to send send or receive +// group membership report delayed report report +// +// Not shown in the diagram above, but any state may transition into the non +// member state when a group is left. const ( // nonMember is the "'Non-Member' state, when the host does not belong to the // group on the interface. This is the initial state for all memberships on @@ -41,6 +58,15 @@ const ( // but without advertising the membership to the network. nonMember hostState = iota + // pendingMember is a newly joined member that is waiting to successfully send + // the initial set of reports. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the initial report needs to be sent. + // + // MAY NOT transition to the idle member state from this state. + pendingMember + // delayingMember is the "'Delaying Member' state, when the host belongs to // the group on the interface and has a report delay timer running for that // membership." @@ -48,6 +74,16 @@ const ( // 'Delaying Listener' is the MLDv1 term used to describe this state. delayingMember + // queuedDelayingMember is a delayingMember that failed to send a report after + // its delayed report timer fired. Hosts in this state are waiting to attempt + // retransmission of the delayed report. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the delayed report needs to be sent. + // + // May transition to idle member if a report is received for a group. + queuedDelayingMember + // idleMember is the "Idle Member" state, when the host belongs to the group // on the interface and does not have a report delay timer running for that // membership. @@ -56,6 +92,17 @@ const ( idleMember ) +func (s hostState) isDelayingMember() bool { + switch s { + case nonMember, pendingMember, idleMember: + return false + case delayingMember, queuedDelayingMember: + return true + default: + panic(fmt.Sprintf("unrecognized host state = %d", s)) + } +} + // multicastGroupState holds the Generic Multicast Protocol state for a // multicast group. type multicastGroupState struct { @@ -124,7 +171,10 @@ type GenericMulticastProtocolOptions struct { // can be represented by GenericMulticastProtocolState. type MulticastGroupProtocol interface { // SendReport sends a multicast report for the specified group address. - SendReport(groupAddress tcpip.Address) *tcpip.Error + // + // Returns false if the caller should queue the report to be sent later. Note, + // returning false does not mean that the receiver hit an error. + SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error) // SendLeave sends a multicast leave for the specified group address. SendLeave(groupAddress tcpip.Address) *tcpip.Error @@ -166,6 +216,9 @@ type GenericMulticastProtocolState struct { // // The GenericMulticastProtocolState will only grab the lock when timers/jobs // fire. +// +// Note: the methods on opts.Protocol will always be called while protocolMU is +// held. func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { if g.memberships != nil { panic("attempted to initialize generic membership protocol state twice") @@ -212,6 +265,29 @@ func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { } } +// SendQueuedReportsLocked attempts to send reports for groups that failed to +// send reports during their last attempt. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { + for groupAddress, info := range g.memberships { + switch info.state { + case nonMember, delayingMember, idleMember: + case pendingMember: + // pendingMembers failed to send their initial unsolicited report so try + // to send the report and queue the extra unsolicited reports. + g.maybeSendInitialReportLocked(groupAddress, &info) + case queuedDelayingMember: + // queuedDelayingMembers failed to send their delayed reports so try to + // send the report and transition them to the idle state. + g.maybeSendDelayedReportLocked(groupAddress, &info) + default: + panic(fmt.Sprintf("unrecognized host state = %d", info.state)) + } + g.memberships[groupAddress] = info + } +} + // JoinGroupLocked handles joining a new group. // // If dontInitialize is true, the group will be not be initialized and will be @@ -239,8 +315,7 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) } - info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil - info.state = idleMember + g.maybeSendDelayedReportLocked(groupAddress, &info) g.memberships[groupAddress] = info }), } @@ -347,7 +422,7 @@ func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Ad // multicast address while it has a timer running for that same address // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. - if info, ok := g.memberships[groupAddress]; ok && info.state == delayingMember { + if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { info.delayedReportJob.Cancel() info.lastToSendReport = false info.state = idleMember @@ -360,10 +435,10 @@ func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Ad // Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state != nonMember { - panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state)) + panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state)) } - info.state = idleMember + info.lastToSendReport = false if groupAddress == g.opts.AllNodesAddress { // As per RFC 2236 section 6 page 10 (for IGMPv2), @@ -379,9 +454,25 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t // case. The node starts in Idle Listener state for that address on // every interface, never transitions to another state, and never sends // a Report or Done for that address. + info.state = idleMember return } + info.state = pendingMember + g.maybeSendInitialReportLocked(groupAddress, info) +} + +// maybeSendInitialReportLocked attempts to start transmission of the initial +// set of reports after newly joining a group. +// +// Host must be in pending member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != pendingMember { + panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state)) + } + // As per RFC 2236 section 3 page 5 (for IGMPv2), // // When a host joins a multicast group, it should immediately transmit an @@ -399,8 +490,30 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t // // TODO(gvisor.dev/issue/4901): Support a configurable number of initial // unsolicited reports. - info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil - g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + } +} + +// maybeSendDelayedReportLocked attempts to send the delayed report. +// +// Host must be in pending, delaying or queued delaying member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if !info.state.isDelayingMember() { + panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state)) + } + + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + info.state = idleMember + } else { + info.state = queuedDelayingMember + } } // maybeSendLeave attempts to send a leave message. @@ -531,6 +644,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr // TODO: Reset the timer if time remaining is greater than maxResponseTime. return } + info.state = delayingMember info.delayedReportJob.Cancel() info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime)) diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index 6fd0eb9f7..95040515c 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -47,6 +47,9 @@ type mockMulticastGroupProtocol struct { // Must only be accessed with mu held. sendLeaveGroupAddrCount map[tcpip.Address]int + + // Must only be accessed with mu held. + makeQueuePackets bool } func (m *mockMulticastGroupProtocol) init() { @@ -60,7 +63,7 @@ func (m *mockMulticastGroupProtocol) initLocked() { m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) } -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error { +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { if m.mu.TryLock() { m.mu.Unlock() m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) @@ -71,7 +74,7 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcp } m.sendReportGroupAddrCount[groupAddress]++ - return nil + return !m.makeQueuePackets, nil } func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { @@ -112,7 +115,7 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t cmp.FilterPath( func(p cmp.Path) bool { - return p.Last().String() == ".mu" || p.Last().String() == ".t" + return p.Last().String() == ".mu" || p.Last().String() == ".t" || p.Last().String() == ".makeQueuePackets" }, cmp.Ignore(), ), @@ -732,3 +735,150 @@ func TestGroupStateNonMember(t *testing.T) { }) } } + +func TestQueuedPackets(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + + // Joining should trigger a SendReport, but mgp should report that we did not + // send the packet. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report timer should have been cancelled since we did not send + // the initial report earlier. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to successfully send the report. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send (we should be idle). + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query but mock being unable to send reports again. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to send reports again - we should have a packet queued to + // send. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query again, but mock being unable to send reports. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report should should transition us into the idle member state, + // even if we had a packet queued. We should no longer have any packets to + // send. + mgp.mu.Lock() + g.HandleReportLocked(addr1) + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // When we fail to send the initial set of reports, incoming reports should + // not affect a newly joined group's reports from being sent. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.HandleReportLocked(addr2) + // Attempting to send queued reports while still unable to send reports should + // not change the host state. + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Mock being able to successfully send the report. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index a3a7176a0..fb7a9e68e 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -96,7 +96,9 @@ type igmpState struct { } // SendReport implements ip.MulticastGroupProtocol. -func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error { +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { igmpType := header.IGMPv2MembershipReport if igmp.v1Present() { igmpType = header.IGMPv1MembershipReport @@ -105,6 +107,8 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error { } // SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: igmp.ep.mu must be read locked. func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // As per RFC 2236 Section 6, Page 8: "If the interface state says the // Querier is running IGMPv1, this action SHOULD be skipped. If the flag @@ -113,7 +117,8 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { if igmp.v1Present() { return nil } - return igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + return err } // init sets up an igmpState struct, and is required to be called before using @@ -235,9 +240,10 @@ func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { igmp.genericMulticastProtocol.HandleReportLocked(groupAddress) } -// writePacket assembles and sends an IGMP packet with the provided fields, -// incrementing the provided stat counter on success. -func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) *tcpip.Error { +// writePacket assembles and sends an IGMP packet. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) { igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) igmpData.SetType(igmpType) igmpData.SetGroupAddress(groupAddress) @@ -248,9 +254,13 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip Data: buffer.View(igmpData).ToVectorisedView(), }) - // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, - // rather we should select an appropriate local address. - localAddr := header.IPv4Any + addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */) + if addressEndpoint == nil { + return false, nil + } + localAddr := addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() + addressEndpoint = nil igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.IGMPProtocolNumber, TTL: header.IGMPTTL, @@ -259,22 +269,22 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip &header.IPv4SerializableRouterAlertOption{}, }) - sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.Dropped.Increment() - return err + sentStats.Dropped.Increment() + return false, err } switch igmpType { case header.IGMPv1MembershipReport: - sent.V1MembershipReport.Increment() + sentStats.V1MembershipReport.Increment() case header.IGMPv2MembershipReport: - sent.V2MembershipReport.Increment() + sentStats.V2MembershipReport.Increment() case header.IGMPLeaveGroup: - sent.LeaveGroup.Increment() + sentStats.LeaveGroup.Increment() default: panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) } - return nil + return true, nil } // joinGroup handles adding a new group to the membership map, setting up the @@ -325,3 +335,10 @@ func (igmp *igmpState) softLeaveAll() { func (igmp *igmpState) initializeAll() { igmp.genericMulticastProtocol.InitializeGroupsLocked() } + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) sendQueuedReports() { + igmp.genericMulticastProtocol.SendQueuedReportsLocked() +} diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 5e139377b..1ee573ac8 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -16,6 +16,7 @@ package ipv4_test import ( "testing" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -29,6 +30,7 @@ import ( const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + addr = tcpip.Address("\x0a\x00\x00\x01") multicastAddr = tcpip.Address("\xe0\x00\x00\x03") nicID = 1 ) @@ -41,6 +43,7 @@ func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) checker.IPv4(t, payload, + checker.SrcAddr(addr), checker.DstAddr(remoteAddress), // TTL for an IGMP message must be 1 as per RFC 2236 section 2. checker.TTL(1), @@ -71,7 +74,6 @@ func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stac if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - return e, s, clock } @@ -104,6 +106,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma // reports for backwards compatibility. func TestIgmpV1Present(t *testing.T) { e, s, clock := createStack(t, true) + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) @@ -154,3 +159,57 @@ func TestIgmpV1Present(t *testing.T) { } validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) } + +func TestSendQueuedIGMPReports(t *testing.T) { + e, s, clock := createStack(t, true) + + // Joining a group without an assigned address should queue IGMP packets; none + // should be sent without an assigned address. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) + } + reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport + if got := reportStat.Value(); got != 0 { + t.Errorf("got reportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } + + // The initial set of IGMP reports that were queued should be sent once an + // address is assigned. + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } + if got := reportStat.Value(); got != 1 { + t.Errorf("got reportStat.Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + if got := reportStat.Value(); got != 2 { + t.Errorf("got reportStat.Value() = %d, want = 2", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + + // Should have no more packets to send after the initial set of unsolicited + // reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index c63ecca4a..e9ff70d04 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -172,7 +172,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.isEnabled() { return } @@ -189,6 +189,10 @@ func (e *endpoint) disableLocked() { if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // DefaultTTL is the default time-to-live value for this endpoint. @@ -780,7 +784,12 @@ func (e *endpoint) Close() { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err == nil { + e.mu.igmp.sendQueuedReports() + } + return ep, err } // RemovePermanentAddress implements stack.AddressableEndpoint. @@ -815,6 +824,14 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { e.mu.RLock() defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements +// +// Precondition: igmp.ep.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) } diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 5e75c8740..afa45aefe 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -58,7 +58,10 @@ go_test( srcs = ["mld_test.go"], deps = [ ":ipv6", + "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 7288e309c..e506e99e9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -121,6 +121,45 @@ type OpaqueInterfaceIdentifierOptions struct { SecretKey []byte } +// onAddressAssignedLocked handles an address being assigned. +// +// Precondition: e.mu must be exclusively locked. +func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, ... + // + // If we just completed DAD for a link-local address, then attempt to send any + // queued MLD reports. Note, we may have sent reports already for some of the + // groups before we had a valid link-local address to use as the source for + // the MLD messages, but that was only so that MLD snooping switches are aware + // of our membership to groups - routers would not have handled those reports. + // + // As per RFC 3590 section 4, + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + if header.IsV6LinkLocalAddress(addr) { + e.mu.mld.sendQueuedReports() + } +} + // InvalidateDefaultRouter implements stack.NDPEndpoint. func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.Lock() @@ -333,7 +372,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.Enabled() { return } @@ -349,6 +388,10 @@ func (e *endpoint) disableLocked() { // Leave groups from the perspective of MLD so that routers know that // we are no longer interested in the group. e.mu.mld.softLeaveAll() + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -1176,13 +1219,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre return addressEndpoint, nil } - snmc := header.SolicitedNodeAddr(addr.Address) - if err := e.joinGroupLocked(snmc); err != nil { - // joinGroupLocked only returns an error if the group address is not a valid - // IPv6 multicast address. - panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) - } - addressEndpoint.SetKind(stack.PermanentTentative) if e.Enabled() { @@ -1191,6 +1227,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } } + snmc := header.SolicitedNodeAddr(addr.Address) + if err := e.joinGroupLocked(snmc); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) + } + return addressEndpoint, nil } @@ -1292,6 +1335,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) } +// getLinkLocalAddressRLocked returns a link-local address from the primary list +// of addresses, if one is available. +// +// See stack.PrimaryEndpointBehavior for more details about the primary list. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { + var linkLocalAddr tcpip.Address + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.IsAssigned(false /* allowExpired */) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + linkLocalAddr = addr + return false + } + } + return true + }) + return linkLocalAddr +} + // acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress // but with locking requirements. // @@ -1311,10 +1374,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // Create a candidate set of available addresses we can potentially use as a // source address. var cs []addrCandidate - e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { // If r is not valid for outgoing connections, it is not a valid endpoint. if !addressEndpoint.IsAssigned(allowExpired) { - return + return true } addr := addressEndpoint.AddressWithPrefix().Address @@ -1330,6 +1393,8 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address addressEndpoint: addressEndpoint, scope: scope, }) + + return true }) remoteScope, err := header.ScopeForIPv6Address(remoteAddr) diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index b67eafdba..48644d9c8 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -59,13 +59,18 @@ type mldState struct { } // SendReport implements ip.MulticastGroupProtocol. -func (mld *mldState) SendReport(groupAddress tcpip.Address) *tcpip.Error { +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) } // SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: mld.ep.mu must be read locked. func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { - return mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + return err } // init sets up an mldState struct, and is required to be called before using @@ -147,7 +152,17 @@ func (mld *mldState) initializeAll() { mld.genericMulticastProtocol.InitializeGroupsLocked() } -func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error { +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) sendQueuedReports() { + mld.genericMulticastProtocol.SendQueuedReportsLocked() +} + +// writePacket assembles and sends an MLD packet. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent var mldStat *tcpip.StatCounter switch mldType { @@ -162,9 +177,61 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize)) icmp.SetType(mldType) header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress) - // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, - // rather we should select an appropriate local address. - localAddress := header.IPv6Any + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert + // option in a Hop-by-Hop Options header. + // + // However, this would cause problems with Duplicate Address Detection with + // the first address as MLD snooping switches may not send multicast traffic + // that DAD depends on to the node performing DAD without the MLD report, as + // documented in RFC 4816: + // + // Note that when a node joins a multicast address, it typically sends a + // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810] + // for the multicast address. In the case of Duplicate Address + // Detection, the MLD report message is required in order to inform MLD- + // snooping switches, rather than routers, to forward multicast packets. + // In the above description, the delay for joining the multicast address + // thus means delaying transmission of the corresponding MLD report + // message. Since the MLD specifications do not request a random delay + // to avoid race conditions, just delaying Neighbor Solicitation would + // cause congestion by the MLD report messages. The congestion would + // then prevent the MLD-snooping switches from working correctly and, as + // a result, prevent Duplicate Address Detection from working. The + // requirement to include the delay for the MLD report in this case + // avoids this scenario. [RFC3590] also talks about some interaction + // issues between Duplicate Address Detection and MLD, and specifies + // which source address should be used for the MLD report in this case. + // + // As per RFC 3590 section 4, we should still send out MLD reports with an + // unspecified source address if we do not have an assigned link-local + // address to use as the source address to ensure DAD works as expected on + // networks with MLD snooping switches: + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + localAddress := mld.ep.getLinkLocalAddressRLocked() + if len(localAddress) == 0 { + localAddress = header.IPv6Any + } + icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -180,8 +247,8 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp // Membership Reports. if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sentStats.Dropped.Increment() - return err + return false, err } mldStat.Increment() - return nil + return localAddress != header.IPv6Any, nil } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 5677bdd54..93b8b3c5c 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -16,8 +16,12 @@ package ipv6_test import ( "testing" + "time" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -25,9 +29,31 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" ) +var ( + linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) + globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) +) + +func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { + t.Helper() + + checker.IPv6(t, p, + checker.SrcAddr(localAddress), + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(mldType, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { const nicID = 1 @@ -46,45 +72,223 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { // The stack will join an address's solicited node multicast address when // an address is added. An MLD report message should be sent for the // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) } - { - p, ok := e.Read() - if !ok { - t.Fatal("expected a report message to be sent") - } - snmc := header.SolicitedNodeAddr(addr1) - checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), - checker.DstAddr(snmc), - // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. - checker.TTL(1), - checker.MLD(header.ICMPv6MulticastListenerReport, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(snmc), - ), - ) + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) } // The stack will leave an address's solicited node multicast address when // an address is removed. An MLD done message should be sent for the // solicited-node group. - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a done message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) } - { - p, ok := e.Read() - if !ok { - t.Fatal("expected a done message to be sent") - } - snmc := header.SolicitedNodeAddr(addr1) - checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(1), - checker.MLD(header.ICMPv6MulticastListenerDone, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(snmc), - ), - ) +} + +func TestSendQueuedMLDReports(t *testing.T) { + const ( + nicID = 1 + maxReports = 2 + ) + + tests := []struct { + name string + dadTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD Disabled", + dadTransmits: 0, + retransmitTimer: 0, + }, + { + name: "DAD Enabled", + dadTransmits: 1, + retransmitTimer: time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dadTransmits, + RetransmitTimer: test.retransmitTimer, + }, + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + Clock: clock, + }) + + // Allow space for an extra packet so we can observe packets that were + // unexpectedly sent. + e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + resolveDAD := func(addr, snmc tcpip.Address) { + clock.Advance(dadResolutionTime) + if p, ok := e.Read(); !ok { + t.Fatal("expected DAD packet") + } else { + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(addr), + checker.NDPNSOptions(nil), + )) + } + } + + var reportCounter uint64 + reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + var doneCounter uint64 + doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + + // Joining a group without an assigned address should send an MLD report + // with the unspecified address. + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalMulticastAddr) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a global address should not send reports for the already joined + // group since we should only send queued reports when a link-local + // addres sis assigned. + // + // Note, we will still expect to send a report for the global address's + // solicited node address from the unspecified address as per RFC 3590 + // section 4. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) + } + if dadResolutionTime != 0 { + // Reports should not be sent when the address resolves. + resolveDAD(globalAddr, globalAddrSNMC) + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + } + // Leave the group since we don't care about the global address's + // solicited node multicast group membership. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) + } + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a link-local address should send a report for its solicited node + // address and globalMulticastAddr. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + } + if dadResolutionTime != 0 { + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + resolveDAD(linkLocalAddr, linkLocalAddrSNMC) + } + + // We expect two batches of reports to be sent (1 batch when the + // link-local address is assigned, and another after the maximum + // unsolicited report interval. + for i := 0; i < 2; i++ { + // We expect reports to be sent (one for globalMulticastAddr and another + // for linkLocalAddrSNMC). + reportCounter += maxReports + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + + addrs := map[tcpip.Address]bool{ + globalMulticastAddr: false, + linkLocalAddrSNMC: false, + } + for _ = range addrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) + } + + addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() + if seen, ok := addrs[addr]; !ok { + t.Fatalf("got unexpected packet destined to %s", addr) + } else if seen { + t.Fatalf("got another packet destined to %s", addr) + } + + addrs[addr] = true + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) + + clock.Advance(ipv6.UnsolicitedReportIntervalMax) + } + } + + // Should not send any more reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + }) } } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 2f5e2e82c..3b892aeda 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -647,6 +647,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } + ndp.ep.onAddressAssignedLocked(addr) return nil } @@ -690,12 +691,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) } - // If DAD resolved for a stable SLAAC address, attempt generation of a - // temporary SLAAC address. - if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + if dadDone { + if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { + // Reset the generation attempts counter as we are starting the + // generation of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + } + + ndp.ep.onAddressAssignedLocked(addr) } }), } diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 152986026..6579cd3c9 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -35,6 +35,9 @@ import ( const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + ipv4Addr = tcpip.Address("\x0a\x00\x00\x01") + ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") @@ -49,6 +52,8 @@ const ( mldQuery = uint8(header.ICMPv6MulticastListenerQuery) mldReport = uint8(header.ICMPv6MulticastListenerReport) mldDone = uint8(header.ICMPv6MulticastListenerDone) + + maxUnsolicitedReports = 2 ) var ( @@ -62,6 +67,8 @@ var ( } return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) }() + + ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr) ) // validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet @@ -71,6 +78,7 @@ func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.A payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) checker.IPv6(t, payload, + checker.SrcAddr(ipv6Addr), checker.DstAddr(remoteAddress), // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. checker.TTL(1), @@ -88,6 +96,7 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) checker.IPv4(t, payload, + checker.SrcAddr(ipv4Addr), checker.DstAddr(remoteAddress), // TTL for an IGMP message must be 1 as per RFC 2236 section 2. checker.TTL(1), @@ -100,30 +109,31 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. ) } -func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { +func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { t.Helper() - // Create an endpoint of queue size 2, since no more than 2 packets are ever - // queued in the tests in this file. - e := channel.New(2, header.IPv6MinimumMTU, linkAddr) - s, clock := createStackWithLinkEndpoint(t, mgpEnabled, e) + e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) + s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) return e, s, clock } -func createStackWithLinkEndpoint(t *testing.T, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { +func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { t.Helper() + igmpEnabled := v4 && mgpEnabled + mldEnabled := !v4 && mgpEnabled + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocolWithOptions(ipv4.Options{ IGMP: ipv4.IGMPOptions{ - Enabled: mgpEnabled, + Enabled: igmpEnabled, }, }), ipv6.NewProtocolWithOptions(ipv6.Options{ MLD: ipv6.MLDOptions{ - Enabled: mgpEnabled, + Enabled: mldEnabled, }, }), }, @@ -132,10 +142,61 @@ func createStackWithLinkEndpoint(t *testing.T, mgpEnabled bool, e stack.LinkEndp if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err) + } return s, clock } +// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join +// when it is created with an IPv6 address. +// +// To not interfere with tests, checkInitialIPv6Groups will leave the added +// address's solicited node multicast group so that the tests can all assume +// the NIC has not joined any IPv6 groups. +func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { + t.Helper() + + stats := s.Stats().ICMP.V6.PacketsSent + + reportCounter++ + if got := stats.MulticastListenerReport.Value(); got != reportCounter { + t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) + } + + // Leave the group to not affect the tests. This is fine since we are not + // testing DAD or the solicited node address specifically. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) + } + leaveCounter++ + if got := stats.MulticastListenerDone.Value(); got != leaveCounter { + t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + return reportCounter, leaveCounter +} + // createAndInjectIGMPPacket creates and injects an IGMP packet with the // specified fields. // @@ -240,13 +301,13 @@ func TestMGPDisabled(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, false) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) // This NIC may join multicast groups when it is enabled but since MGP is // disabled, no reports should be sent. sentReportStat := test.sentReportStat(s) if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportState.Value() = %d, want = 0", got) + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) } clock.Advance(time.Hour) if p, ok := e.Read(); ok { @@ -259,7 +320,7 @@ func TestMGPDisabled(t *testing.T) { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportState.Value() = %d, want = 0", got) + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) } clock.Advance(time.Hour) if p, ok := e.Read(); ok { @@ -363,7 +424,7 @@ func TestMGPReceiveCounters(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, true) + e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) if got := test.statCounter(s).Value(); got != 1 { @@ -384,6 +445,7 @@ func TestMGPJoinGroup(t *testing.T) { sentReportStat func(*stack.Stack) *tcpip.StatCounter receivedQueryStat func(*stack.Stack) *tcpip.StatCounter validateReport func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -418,21 +480,28 @@ func TestMGPJoinGroup(t *testing.T) { validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } // Test joining a specific address explicitly and verify a Report is sent // immediately. if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } + reportCounter++ sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportState.Value() = %d, want = 1", got) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -450,8 +519,9 @@ func TestMGPJoinGroup(t *testing.T) { t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) } clock.Advance(test.maxUnsolicitedResponseDelay) - if got := sentReportStat.Value(); got != 2 { - t.Errorf("got sentReportState.Value() = %d, want = 2", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -472,13 +542,14 @@ func TestMGPJoinGroup(t *testing.T) { // group the stack sends a leave/done message. func TestMGPLeaveGroup(t *testing.T) { tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo) - validateLeave func(*testing.T, channel.PacketInfo) + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + validateLeave func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -521,18 +592,26 @@ func TestMGPLeaveGroup(t *testing.T) { validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } - if got := test.sentReportStat(s).Value(); got != 1 { - t.Errorf("got sentReportStat(_).Value() = %d, want = 1", got) + reportCounter++ + if got := test.sentReportStat(s).Value(); got != reportCounter { + t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -547,8 +626,9 @@ func TestMGPLeaveGroup(t *testing.T) { if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) } - if got := test.sentLeaveStat(s).Value(); got != 1 { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 1", got) + leaveCounter++ + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a leave message to be sent") @@ -578,6 +658,7 @@ func TestMGPQueryMessages(t *testing.T) { rxQuery func(*channel.Endpoint, uint8, tcpip.Address) validateReport func(*testing.T, channel.PacketInfo) maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -622,6 +703,7 @@ func TestMGPQueryMessages(t *testing.T) { maxRespTimeToDuration: func(d uint8) time.Duration { return time.Duration(d) * time.Millisecond }, + checkInitialGroups: checkInitialIPv6Groups, }, } @@ -655,16 +737,22 @@ func TestMGPQueryMessages(t *testing.T) { for _, subTest := range subTests { t.Run(subTest.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } sentReportStat := test.sentReportStat(s) - for i := uint64(1); i <= 2; i++ { + for i := 0; i < maxUnsolicitedReports; i++ { sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != i { - t.Errorf("(i=%d) got sentReportState.Value() = %d, want = %d", i, got, i) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatalf("expected %d-th report message to be sent", i) @@ -694,8 +782,9 @@ func TestMGPQueryMessages(t *testing.T) { if subTest.expectReport { clock.Advance(test.maxRespTimeToDuration(maxRespTime)) - if got := sentReportStat.Value(); got != 3 { - t.Errorf("got sentReportState.Value() = %d, want = 3", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -727,6 +816,7 @@ func TestMGPReportMessages(t *testing.T) { rxReport func(*channel.Endpoint) validateReport func(*testing.T, channel.PacketInfo) maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -769,19 +859,27 @@ func TestMGPReportMessages(t *testing.T) { maxRespTimeToDuration: func(d uint8) time.Duration { return time.Duration(d) * time.Millisecond }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportStat.Value() = %d, want = 1", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -796,8 +894,8 @@ func TestMGPReportMessages(t *testing.T) { // reports. test.rxReport(e) clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportStat.Value() = %d, want = 1", got) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); ok { t.Errorf("sent unexpected packet = %#v", p) @@ -812,8 +910,8 @@ func TestMGPReportMessages(t *testing.T) { t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) } clock.Advance(time.Hour) - if got := test.sentLeaveStat(s).Value(); got != 0 { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 0", got) + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) } // Should not send any more packets. @@ -837,6 +935,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -922,17 +1021,22 @@ func TestMGPWithNICLifecycle(t *testing.T) { } seen[addr] = true return addr - }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - sentReportStat := test.sentReportStat(s) var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + sentReportStat := test.sentReportStat(s) for _, a := range test.multicastAddrs { if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) @@ -957,7 +1061,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { t.Fatalf("DisableNIC(%d): %s", nicID, err) } sentLeaveStat := test.sentLeaveStat(s) - leaveCounter := uint64(len(test.multicastAddrs)) + leaveCounter += uint64(len(test.multicastAddrs)) if got := sentLeaveStat.Value(); got != leaveCounter { t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) } @@ -1059,7 +1163,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { clock.Advance(test.maxUnsolicitedResponseDelay) reportCounter++ if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportState.Value() = %d, want = %d", got, reportCounter) + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -1105,7 +1209,7 @@ func TestMGPDisabledOnLoopback(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s, clock := createStackWithLinkEndpoint(t, true /* mgpEnabled */, loopback.New()) + s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) sentReportStat := test.sentReportStat(s) if got := sentReportStat.Value(); got != 0 { diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 6e4f5fa46..cd423bf71 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -82,12 +82,16 @@ func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) } // ForEachPrimaryEndpoint calls f for each primary address. -func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { +// +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) { a.mu.RLock() defer a.mu.RUnlock() for _, ep := range a.mu.primary { - f(ep) + if !f(ep) { + return + } } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 5887aa1ed..a6237dd5f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -172,7 +172,7 @@ func (n *NIC) disable() { // // n MUST be locked. func (n *NIC) disableLocked() { - if !n.setEnabled(false) { + if !n.Enabled() { return } @@ -184,6 +184,10 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() } + + if !n.setEnabled(false) { + panic("should have only done work to disable the NIC if it was enabled") + } } // enable enables n. -- cgit v1.2.3