diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol.go | 130 | ||||
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol_test.go | 156 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 47 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp_test.go | 61 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/BUILD | 3 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 85 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld.go | 83 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld_test.go | 272 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/network/multicast_group_test.go | 200 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 6 |
13 files changed, 949 insertions, 138 deletions
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index c8c17ab15..f85c5ff9d 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -30,6 +30,23 @@ type hostState int // The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1 // (RFC 2710 section 5). Even though the states are generic across both IGMPv2 // and MLDv1, IGMPv2 terminology will be used. +// +// ______________receive query______________ +// | | +// | _____send or receive report_____ | +// | | | | +// V | V | +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ | +// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | - +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ +// | ^ | ^ | ^ | ^ +// | | | | | | | | +// ---------- ------- ---------- ------------- +// initialize new send inital fail to send send or receive +// group membership report delayed report report +// +// Not shown in the diagram above, but any state may transition into the non +// member state when a group is left. const ( // nonMember is the "'Non-Member' state, when the host does not belong to the // group on the interface. This is the initial state for all memberships on @@ -41,6 +58,15 @@ const ( // but without advertising the membership to the network. nonMember hostState = iota + // pendingMember is a newly joined member that is waiting to successfully send + // the initial set of reports. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the initial report needs to be sent. + // + // MAY NOT transition to the idle member state from this state. + pendingMember + // delayingMember is the "'Delaying Member' state, when the host belongs to // the group on the interface and has a report delay timer running for that // membership." @@ -48,6 +74,16 @@ const ( // 'Delaying Listener' is the MLDv1 term used to describe this state. delayingMember + // queuedDelayingMember is a delayingMember that failed to send a report after + // its delayed report timer fired. Hosts in this state are waiting to attempt + // retransmission of the delayed report. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the delayed report needs to be sent. + // + // May transition to idle member if a report is received for a group. + queuedDelayingMember + // idleMember is the "Idle Member" state, when the host belongs to the group // on the interface and does not have a report delay timer running for that // membership. @@ -56,6 +92,17 @@ const ( idleMember ) +func (s hostState) isDelayingMember() bool { + switch s { + case nonMember, pendingMember, idleMember: + return false + case delayingMember, queuedDelayingMember: + return true + default: + panic(fmt.Sprintf("unrecognized host state = %d", s)) + } +} + // multicastGroupState holds the Generic Multicast Protocol state for a // multicast group. type multicastGroupState struct { @@ -124,7 +171,10 @@ type GenericMulticastProtocolOptions struct { // can be represented by GenericMulticastProtocolState. type MulticastGroupProtocol interface { // SendReport sends a multicast report for the specified group address. - SendReport(groupAddress tcpip.Address) *tcpip.Error + // + // Returns false if the caller should queue the report to be sent later. Note, + // returning false does not mean that the receiver hit an error. + SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error) // SendLeave sends a multicast leave for the specified group address. SendLeave(groupAddress tcpip.Address) *tcpip.Error @@ -166,6 +216,9 @@ type GenericMulticastProtocolState struct { // // The GenericMulticastProtocolState will only grab the lock when timers/jobs // fire. +// +// Note: the methods on opts.Protocol will always be called while protocolMU is +// held. func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { if g.memberships != nil { panic("attempted to initialize generic membership protocol state twice") @@ -212,6 +265,29 @@ func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { } } +// SendQueuedReportsLocked attempts to send reports for groups that failed to +// send reports during their last attempt. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { + for groupAddress, info := range g.memberships { + switch info.state { + case nonMember, delayingMember, idleMember: + case pendingMember: + // pendingMembers failed to send their initial unsolicited report so try + // to send the report and queue the extra unsolicited reports. + g.maybeSendInitialReportLocked(groupAddress, &info) + case queuedDelayingMember: + // queuedDelayingMembers failed to send their delayed reports so try to + // send the report and transition them to the idle state. + g.maybeSendDelayedReportLocked(groupAddress, &info) + default: + panic(fmt.Sprintf("unrecognized host state = %d", info.state)) + } + g.memberships[groupAddress] = info + } +} + // JoinGroupLocked handles joining a new group. // // If dontInitialize is true, the group will be not be initialized and will be @@ -239,8 +315,7 @@ func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Addre panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) } - info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil - info.state = idleMember + g.maybeSendDelayedReportLocked(groupAddress, &info) g.memberships[groupAddress] = info }), } @@ -347,7 +422,7 @@ func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Ad // multicast address while it has a timer running for that same address // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. - if info, ok := g.memberships[groupAddress]; ok && info.state == delayingMember { + if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { info.delayedReportJob.Cancel() info.lastToSendReport = false info.state = idleMember @@ -360,10 +435,10 @@ func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Ad // Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state != nonMember { - panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state)) + panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state)) } - info.state = idleMember + info.lastToSendReport = false if groupAddress == g.opts.AllNodesAddress { // As per RFC 2236 section 6 page 10 (for IGMPv2), @@ -379,9 +454,25 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t // case. The node starts in Idle Listener state for that address on // every interface, never transitions to another state, and never sends // a Report or Done for that address. + info.state = idleMember return } + info.state = pendingMember + g.maybeSendInitialReportLocked(groupAddress, info) +} + +// maybeSendInitialReportLocked attempts to start transmission of the initial +// set of reports after newly joining a group. +// +// Host must be in pending member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != pendingMember { + panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state)) + } + // As per RFC 2236 section 3 page 5 (for IGMPv2), // // When a host joins a multicast group, it should immediately transmit an @@ -399,8 +490,30 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t // // TODO(gvisor.dev/issue/4901): Support a configurable number of initial // unsolicited reports. - info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil - g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + } +} + +// maybeSendDelayedReportLocked attempts to send the delayed report. +// +// Host must be in pending, delaying or queued delaying member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if !info.state.isDelayingMember() { + panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state)) + } + + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + info.state = idleMember + } else { + info.state = queuedDelayingMember + } } // maybeSendLeave attempts to send a leave message. @@ -531,6 +644,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr // TODO: Reset the timer if time remaining is greater than maxResponseTime. return } + info.state = delayingMember info.delayedReportJob.Cancel() info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime)) diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index 6fd0eb9f7..95040515c 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -47,6 +47,9 @@ type mockMulticastGroupProtocol struct { // Must only be accessed with mu held. sendLeaveGroupAddrCount map[tcpip.Address]int + + // Must only be accessed with mu held. + makeQueuePackets bool } func (m *mockMulticastGroupProtocol) init() { @@ -60,7 +63,7 @@ func (m *mockMulticastGroupProtocol) initLocked() { m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) } -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error { +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { if m.mu.TryLock() { m.mu.Unlock() m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) @@ -71,7 +74,7 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcp } m.sendReportGroupAddrCount[groupAddress]++ - return nil + return !m.makeQueuePackets, nil } func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { @@ -112,7 +115,7 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t cmp.FilterPath( func(p cmp.Path) bool { - return p.Last().String() == ".mu" || p.Last().String() == ".t" + return p.Last().String() == ".mu" || p.Last().String() == ".t" || p.Last().String() == ".makeQueuePackets" }, cmp.Ignore(), ), @@ -732,3 +735,150 @@ func TestGroupStateNonMember(t *testing.T) { }) } } + +func TestQueuedPackets(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + + // Joining should trigger a SendReport, but mgp should report that we did not + // send the packet. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report timer should have been cancelled since we did not send + // the initial report earlier. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to successfully send the report. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send (we should be idle). + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query but mock being unable to send reports again. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to send reports again - we should have a packet queued to + // send. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query again, but mock being unable to send reports. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report should should transition us into the idle member state, + // even if we had a packet queued. We should no longer have any packets to + // send. + mgp.mu.Lock() + g.HandleReportLocked(addr1) + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // When we fail to send the initial set of reports, incoming reports should + // not affect a newly joined group's reports from being sent. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.HandleReportLocked(addr2) + // Attempting to send queued reports while still unable to send reports should + // not change the host state. + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Mock being able to successfully send the report. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index a3a7176a0..fb7a9e68e 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -96,7 +96,9 @@ type igmpState struct { } // SendReport implements ip.MulticastGroupProtocol. -func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error { +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { igmpType := header.IGMPv2MembershipReport if igmp.v1Present() { igmpType = header.IGMPv1MembershipReport @@ -105,6 +107,8 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error { } // SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: igmp.ep.mu must be read locked. func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // As per RFC 2236 Section 6, Page 8: "If the interface state says the // Querier is running IGMPv1, this action SHOULD be skipped. If the flag @@ -113,7 +117,8 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { if igmp.v1Present() { return nil } - return igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + return err } // init sets up an igmpState struct, and is required to be called before using @@ -235,9 +240,10 @@ func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { igmp.genericMulticastProtocol.HandleReportLocked(groupAddress) } -// writePacket assembles and sends an IGMP packet with the provided fields, -// incrementing the provided stat counter on success. -func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) *tcpip.Error { +// writePacket assembles and sends an IGMP packet. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) { igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) igmpData.SetType(igmpType) igmpData.SetGroupAddress(groupAddress) @@ -248,9 +254,13 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip Data: buffer.View(igmpData).ToVectorisedView(), }) - // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, - // rather we should select an appropriate local address. - localAddr := header.IPv4Any + addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */) + if addressEndpoint == nil { + return false, nil + } + localAddr := addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() + addressEndpoint = nil igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.IGMPProtocolNumber, TTL: header.IGMPTTL, @@ -259,22 +269,22 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip &header.IPv4SerializableRouterAlertOption{}, }) - sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.Dropped.Increment() - return err + sentStats.Dropped.Increment() + return false, err } switch igmpType { case header.IGMPv1MembershipReport: - sent.V1MembershipReport.Increment() + sentStats.V1MembershipReport.Increment() case header.IGMPv2MembershipReport: - sent.V2MembershipReport.Increment() + sentStats.V2MembershipReport.Increment() case header.IGMPLeaveGroup: - sent.LeaveGroup.Increment() + sentStats.LeaveGroup.Increment() default: panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) } - return nil + return true, nil } // joinGroup handles adding a new group to the membership map, setting up the @@ -325,3 +335,10 @@ func (igmp *igmpState) softLeaveAll() { func (igmp *igmpState) initializeAll() { igmp.genericMulticastProtocol.InitializeGroupsLocked() } + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) sendQueuedReports() { + igmp.genericMulticastProtocol.SendQueuedReportsLocked() +} diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 5e139377b..1ee573ac8 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -16,6 +16,7 @@ package ipv4_test import ( "testing" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -29,6 +30,7 @@ import ( const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + addr = tcpip.Address("\x0a\x00\x00\x01") multicastAddr = tcpip.Address("\xe0\x00\x00\x03") nicID = 1 ) @@ -41,6 +43,7 @@ func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) checker.IPv4(t, payload, + checker.SrcAddr(addr), checker.DstAddr(remoteAddress), // TTL for an IGMP message must be 1 as per RFC 2236 section 2. checker.TTL(1), @@ -71,7 +74,6 @@ func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stac if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - return e, s, clock } @@ -104,6 +106,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma // reports for backwards compatibility. func TestIgmpV1Present(t *testing.T) { e, s, clock := createStack(t, true) + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) @@ -154,3 +159,57 @@ func TestIgmpV1Present(t *testing.T) { } validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) } + +func TestSendQueuedIGMPReports(t *testing.T) { + e, s, clock := createStack(t, true) + + // Joining a group without an assigned address should queue IGMP packets; none + // should be sent without an assigned address. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) + } + reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport + if got := reportStat.Value(); got != 0 { + t.Errorf("got reportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } + + // The initial set of IGMP reports that were queued should be sent once an + // address is assigned. + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } + if got := reportStat.Value(); got != 1 { + t.Errorf("got reportStat.Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + if got := reportStat.Value(); got != 2 { + t.Errorf("got reportStat.Value() = %d, want = 2", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + + // Should have no more packets to send after the initial set of unsolicited + // reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index c63ecca4a..e9ff70d04 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -172,7 +172,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.isEnabled() { return } @@ -189,6 +189,10 @@ func (e *endpoint) disableLocked() { if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // DefaultTTL is the default time-to-live value for this endpoint. @@ -780,7 +784,12 @@ func (e *endpoint) Close() { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err == nil { + e.mu.igmp.sendQueuedReports() + } + return ep, err } // RemovePermanentAddress implements stack.AddressableEndpoint. @@ -815,6 +824,14 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { e.mu.RLock() defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements +// +// Precondition: igmp.ep.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) } diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 5e75c8740..afa45aefe 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -58,7 +58,10 @@ go_test( srcs = ["mld_test.go"], deps = [ ":ipv6", + "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 7288e309c..e506e99e9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -121,6 +121,45 @@ type OpaqueInterfaceIdentifierOptions struct { SecretKey []byte } +// onAddressAssignedLocked handles an address being assigned. +// +// Precondition: e.mu must be exclusively locked. +func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, ... + // + // If we just completed DAD for a link-local address, then attempt to send any + // queued MLD reports. Note, we may have sent reports already for some of the + // groups before we had a valid link-local address to use as the source for + // the MLD messages, but that was only so that MLD snooping switches are aware + // of our membership to groups - routers would not have handled those reports. + // + // As per RFC 3590 section 4, + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + if header.IsV6LinkLocalAddress(addr) { + e.mu.mld.sendQueuedReports() + } +} + // InvalidateDefaultRouter implements stack.NDPEndpoint. func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.Lock() @@ -333,7 +372,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.Enabled() { return } @@ -349,6 +388,10 @@ func (e *endpoint) disableLocked() { // Leave groups from the perspective of MLD so that routers know that // we are no longer interested in the group. e.mu.mld.softLeaveAll() + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -1176,13 +1219,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre return addressEndpoint, nil } - snmc := header.SolicitedNodeAddr(addr.Address) - if err := e.joinGroupLocked(snmc); err != nil { - // joinGroupLocked only returns an error if the group address is not a valid - // IPv6 multicast address. - panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) - } - addressEndpoint.SetKind(stack.PermanentTentative) if e.Enabled() { @@ -1191,6 +1227,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } } + snmc := header.SolicitedNodeAddr(addr.Address) + if err := e.joinGroupLocked(snmc); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) + } + return addressEndpoint, nil } @@ -1292,6 +1335,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) } +// getLinkLocalAddressRLocked returns a link-local address from the primary list +// of addresses, if one is available. +// +// See stack.PrimaryEndpointBehavior for more details about the primary list. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { + var linkLocalAddr tcpip.Address + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.IsAssigned(false /* allowExpired */) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + linkLocalAddr = addr + return false + } + } + return true + }) + return linkLocalAddr +} + // acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress // but with locking requirements. // @@ -1311,10 +1374,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // Create a candidate set of available addresses we can potentially use as a // source address. var cs []addrCandidate - e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { // If r is not valid for outgoing connections, it is not a valid endpoint. if !addressEndpoint.IsAssigned(allowExpired) { - return + return true } addr := addressEndpoint.AddressWithPrefix().Address @@ -1330,6 +1393,8 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address addressEndpoint: addressEndpoint, scope: scope, }) + + return true }) remoteScope, err := header.ScopeForIPv6Address(remoteAddr) diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index b67eafdba..48644d9c8 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -59,13 +59,18 @@ type mldState struct { } // SendReport implements ip.MulticastGroupProtocol. -func (mld *mldState) SendReport(groupAddress tcpip.Address) *tcpip.Error { +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) } // SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: mld.ep.mu must be read locked. func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { - return mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + return err } // init sets up an mldState struct, and is required to be called before using @@ -147,7 +152,17 @@ func (mld *mldState) initializeAll() { mld.genericMulticastProtocol.InitializeGroupsLocked() } -func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error { +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) sendQueuedReports() { + mld.genericMulticastProtocol.SendQueuedReportsLocked() +} + +// writePacket assembles and sends an MLD packet. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent var mldStat *tcpip.StatCounter switch mldType { @@ -162,9 +177,61 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize)) icmp.SetType(mldType) header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress) - // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, - // rather we should select an appropriate local address. - localAddress := header.IPv6Any + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert + // option in a Hop-by-Hop Options header. + // + // However, this would cause problems with Duplicate Address Detection with + // the first address as MLD snooping switches may not send multicast traffic + // that DAD depends on to the node performing DAD without the MLD report, as + // documented in RFC 4816: + // + // Note that when a node joins a multicast address, it typically sends a + // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810] + // for the multicast address. In the case of Duplicate Address + // Detection, the MLD report message is required in order to inform MLD- + // snooping switches, rather than routers, to forward multicast packets. + // In the above description, the delay for joining the multicast address + // thus means delaying transmission of the corresponding MLD report + // message. Since the MLD specifications do not request a random delay + // to avoid race conditions, just delaying Neighbor Solicitation would + // cause congestion by the MLD report messages. The congestion would + // then prevent the MLD-snooping switches from working correctly and, as + // a result, prevent Duplicate Address Detection from working. The + // requirement to include the delay for the MLD report in this case + // avoids this scenario. [RFC3590] also talks about some interaction + // issues between Duplicate Address Detection and MLD, and specifies + // which source address should be used for the MLD report in this case. + // + // As per RFC 3590 section 4, we should still send out MLD reports with an + // unspecified source address if we do not have an assigned link-local + // address to use as the source address to ensure DAD works as expected on + // networks with MLD snooping switches: + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + localAddress := mld.ep.getLinkLocalAddressRLocked() + if len(localAddress) == 0 { + localAddress = header.IPv6Any + } + icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -180,8 +247,8 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp // Membership Reports. if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sentStats.Dropped.Increment() - return err + return false, err } mldStat.Increment() - return nil + return localAddress != header.IPv6Any, nil } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 5677bdd54..93b8b3c5c 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -16,8 +16,12 @@ package ipv6_test import ( "testing" + "time" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -25,9 +29,31 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" ) +var ( + linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) + globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) +) + +func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { + t.Helper() + + checker.IPv6(t, p, + checker.SrcAddr(localAddress), + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(mldType, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { const nicID = 1 @@ -46,45 +72,223 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { // The stack will join an address's solicited node multicast address when // an address is added. An MLD report message should be sent for the // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) } - { - p, ok := e.Read() - if !ok { - t.Fatal("expected a report message to be sent") - } - snmc := header.SolicitedNodeAddr(addr1) - checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), - checker.DstAddr(snmc), - // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. - checker.TTL(1), - checker.MLD(header.ICMPv6MulticastListenerReport, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(snmc), - ), - ) + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) } // The stack will leave an address's solicited node multicast address when // an address is removed. An MLD done message should be sent for the // solicited-node group. - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a done message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) } - { - p, ok := e.Read() - if !ok { - t.Fatal("expected a done message to be sent") - } - snmc := header.SolicitedNodeAddr(addr1) - checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(1), - checker.MLD(header.ICMPv6MulticastListenerDone, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(snmc), - ), - ) +} + +func TestSendQueuedMLDReports(t *testing.T) { + const ( + nicID = 1 + maxReports = 2 + ) + + tests := []struct { + name string + dadTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD Disabled", + dadTransmits: 0, + retransmitTimer: 0, + }, + { + name: "DAD Enabled", + dadTransmits: 1, + retransmitTimer: time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dadTransmits, + RetransmitTimer: test.retransmitTimer, + }, + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + Clock: clock, + }) + + // Allow space for an extra packet so we can observe packets that were + // unexpectedly sent. + e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + resolveDAD := func(addr, snmc tcpip.Address) { + clock.Advance(dadResolutionTime) + if p, ok := e.Read(); !ok { + t.Fatal("expected DAD packet") + } else { + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(addr), + checker.NDPNSOptions(nil), + )) + } + } + + var reportCounter uint64 + reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + var doneCounter uint64 + doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + + // Joining a group without an assigned address should send an MLD report + // with the unspecified address. + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalMulticastAddr) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a global address should not send reports for the already joined + // group since we should only send queued reports when a link-local + // addres sis assigned. + // + // Note, we will still expect to send a report for the global address's + // solicited node address from the unspecified address as per RFC 3590 + // section 4. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) + } + if dadResolutionTime != 0 { + // Reports should not be sent when the address resolves. + resolveDAD(globalAddr, globalAddrSNMC) + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + } + // Leave the group since we don't care about the global address's + // solicited node multicast group membership. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) + } + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a link-local address should send a report for its solicited node + // address and globalMulticastAddr. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + } + if dadResolutionTime != 0 { + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + resolveDAD(linkLocalAddr, linkLocalAddrSNMC) + } + + // We expect two batches of reports to be sent (1 batch when the + // link-local address is assigned, and another after the maximum + // unsolicited report interval. + for i := 0; i < 2; i++ { + // We expect reports to be sent (one for globalMulticastAddr and another + // for linkLocalAddrSNMC). + reportCounter += maxReports + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + + addrs := map[tcpip.Address]bool{ + globalMulticastAddr: false, + linkLocalAddrSNMC: false, + } + for _ = range addrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) + } + + addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() + if seen, ok := addrs[addr]; !ok { + t.Fatalf("got unexpected packet destined to %s", addr) + } else if seen { + t.Fatalf("got another packet destined to %s", addr) + } + + addrs[addr] = true + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) + + clock.Advance(ipv6.UnsolicitedReportIntervalMax) + } + } + + // Should not send any more reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + }) } } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 2f5e2e82c..3b892aeda 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -647,6 +647,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } + ndp.ep.onAddressAssignedLocked(addr) return nil } @@ -690,12 +691,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) } - // If DAD resolved for a stable SLAAC address, attempt generation of a - // temporary SLAAC address. - if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + if dadDone { + if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { + // Reset the generation attempts counter as we are starting the + // generation of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + } + + ndp.ep.onAddressAssignedLocked(addr) } }), } diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 152986026..6579cd3c9 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -35,6 +35,9 @@ import ( const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + ipv4Addr = tcpip.Address("\x0a\x00\x00\x01") + ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") @@ -49,6 +52,8 @@ const ( mldQuery = uint8(header.ICMPv6MulticastListenerQuery) mldReport = uint8(header.ICMPv6MulticastListenerReport) mldDone = uint8(header.ICMPv6MulticastListenerDone) + + maxUnsolicitedReports = 2 ) var ( @@ -62,6 +67,8 @@ var ( } return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) }() + + ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr) ) // validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet @@ -71,6 +78,7 @@ func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.A payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) checker.IPv6(t, payload, + checker.SrcAddr(ipv6Addr), checker.DstAddr(remoteAddress), // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. checker.TTL(1), @@ -88,6 +96,7 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) checker.IPv4(t, payload, + checker.SrcAddr(ipv4Addr), checker.DstAddr(remoteAddress), // TTL for an IGMP message must be 1 as per RFC 2236 section 2. checker.TTL(1), @@ -100,30 +109,31 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. ) } -func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { +func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { t.Helper() - // Create an endpoint of queue size 2, since no more than 2 packets are ever - // queued in the tests in this file. - e := channel.New(2, header.IPv6MinimumMTU, linkAddr) - s, clock := createStackWithLinkEndpoint(t, mgpEnabled, e) + e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) + s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) return e, s, clock } -func createStackWithLinkEndpoint(t *testing.T, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { +func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { t.Helper() + igmpEnabled := v4 && mgpEnabled + mldEnabled := !v4 && mgpEnabled + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocolWithOptions(ipv4.Options{ IGMP: ipv4.IGMPOptions{ - Enabled: mgpEnabled, + Enabled: igmpEnabled, }, }), ipv6.NewProtocolWithOptions(ipv6.Options{ MLD: ipv6.MLDOptions{ - Enabled: mgpEnabled, + Enabled: mldEnabled, }, }), }, @@ -132,10 +142,61 @@ func createStackWithLinkEndpoint(t *testing.T, mgpEnabled bool, e stack.LinkEndp if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err) + } return s, clock } +// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join +// when it is created with an IPv6 address. +// +// To not interfere with tests, checkInitialIPv6Groups will leave the added +// address's solicited node multicast group so that the tests can all assume +// the NIC has not joined any IPv6 groups. +func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { + t.Helper() + + stats := s.Stats().ICMP.V6.PacketsSent + + reportCounter++ + if got := stats.MulticastListenerReport.Value(); got != reportCounter { + t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) + } + + // Leave the group to not affect the tests. This is fine since we are not + // testing DAD or the solicited node address specifically. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) + } + leaveCounter++ + if got := stats.MulticastListenerDone.Value(); got != leaveCounter { + t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + return reportCounter, leaveCounter +} + // createAndInjectIGMPPacket creates and injects an IGMP packet with the // specified fields. // @@ -240,13 +301,13 @@ func TestMGPDisabled(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, false) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) // This NIC may join multicast groups when it is enabled but since MGP is // disabled, no reports should be sent. sentReportStat := test.sentReportStat(s) if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportState.Value() = %d, want = 0", got) + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) } clock.Advance(time.Hour) if p, ok := e.Read(); ok { @@ -259,7 +320,7 @@ func TestMGPDisabled(t *testing.T) { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportState.Value() = %d, want = 0", got) + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) } clock.Advance(time.Hour) if p, ok := e.Read(); ok { @@ -363,7 +424,7 @@ func TestMGPReceiveCounters(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, true) + e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) if got := test.statCounter(s).Value(); got != 1 { @@ -384,6 +445,7 @@ func TestMGPJoinGroup(t *testing.T) { sentReportStat func(*stack.Stack) *tcpip.StatCounter receivedQueryStat func(*stack.Stack) *tcpip.StatCounter validateReport func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -418,21 +480,28 @@ func TestMGPJoinGroup(t *testing.T) { validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } // Test joining a specific address explicitly and verify a Report is sent // immediately. if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } + reportCounter++ sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportState.Value() = %d, want = 1", got) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -450,8 +519,9 @@ func TestMGPJoinGroup(t *testing.T) { t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) } clock.Advance(test.maxUnsolicitedResponseDelay) - if got := sentReportStat.Value(); got != 2 { - t.Errorf("got sentReportState.Value() = %d, want = 2", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -472,13 +542,14 @@ func TestMGPJoinGroup(t *testing.T) { // group the stack sends a leave/done message. func TestMGPLeaveGroup(t *testing.T) { tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo) - validateLeave func(*testing.T, channel.PacketInfo) + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + validateLeave func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -521,18 +592,26 @@ func TestMGPLeaveGroup(t *testing.T) { validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } - if got := test.sentReportStat(s).Value(); got != 1 { - t.Errorf("got sentReportStat(_).Value() = %d, want = 1", got) + reportCounter++ + if got := test.sentReportStat(s).Value(); got != reportCounter { + t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -547,8 +626,9 @@ func TestMGPLeaveGroup(t *testing.T) { if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) } - if got := test.sentLeaveStat(s).Value(); got != 1 { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 1", got) + leaveCounter++ + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a leave message to be sent") @@ -578,6 +658,7 @@ func TestMGPQueryMessages(t *testing.T) { rxQuery func(*channel.Endpoint, uint8, tcpip.Address) validateReport func(*testing.T, channel.PacketInfo) maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -622,6 +703,7 @@ func TestMGPQueryMessages(t *testing.T) { maxRespTimeToDuration: func(d uint8) time.Duration { return time.Duration(d) * time.Millisecond }, + checkInitialGroups: checkInitialIPv6Groups, }, } @@ -655,16 +737,22 @@ func TestMGPQueryMessages(t *testing.T) { for _, subTest := range subTests { t.Run(subTest.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } sentReportStat := test.sentReportStat(s) - for i := uint64(1); i <= 2; i++ { + for i := 0; i < maxUnsolicitedReports; i++ { sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != i { - t.Errorf("(i=%d) got sentReportState.Value() = %d, want = %d", i, got, i) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatalf("expected %d-th report message to be sent", i) @@ -694,8 +782,9 @@ func TestMGPQueryMessages(t *testing.T) { if subTest.expectReport { clock.Advance(test.maxRespTimeToDuration(maxRespTime)) - if got := sentReportStat.Value(); got != 3 { - t.Errorf("got sentReportState.Value() = %d, want = 3", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -727,6 +816,7 @@ func TestMGPReportMessages(t *testing.T) { rxReport func(*channel.Endpoint) validateReport func(*testing.T, channel.PacketInfo) maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -769,19 +859,27 @@ func TestMGPReportMessages(t *testing.T) { maxRespTimeToDuration: func(d uint8) time.Duration { return time.Duration(d) * time.Millisecond }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportStat.Value() = %d, want = 1", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -796,8 +894,8 @@ func TestMGPReportMessages(t *testing.T) { // reports. test.rxReport(e) clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportStat.Value() = %d, want = 1", got) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); ok { t.Errorf("sent unexpected packet = %#v", p) @@ -812,8 +910,8 @@ func TestMGPReportMessages(t *testing.T) { t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) } clock.Advance(time.Hour) - if got := test.sentLeaveStat(s).Value(); got != 0 { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 0", got) + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) } // Should not send any more packets. @@ -837,6 +935,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -922,17 +1021,22 @@ func TestMGPWithNICLifecycle(t *testing.T) { } seen[addr] = true return addr - }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - sentReportStat := test.sentReportStat(s) var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + sentReportStat := test.sentReportStat(s) for _, a := range test.multicastAddrs { if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) @@ -957,7 +1061,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { t.Fatalf("DisableNIC(%d): %s", nicID, err) } sentLeaveStat := test.sentLeaveStat(s) - leaveCounter := uint64(len(test.multicastAddrs)) + leaveCounter += uint64(len(test.multicastAddrs)) if got := sentLeaveStat.Value(); got != leaveCounter { t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) } @@ -1059,7 +1163,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { clock.Advance(test.maxUnsolicitedResponseDelay) reportCounter++ if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportState.Value() = %d, want = %d", got, reportCounter) + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -1105,7 +1209,7 @@ func TestMGPDisabledOnLoopback(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s, clock := createStackWithLinkEndpoint(t, true /* mgpEnabled */, loopback.New()) + s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) sentReportStat := test.sentReportStat(s) if got := sentReportStat.Value(); got != 0 { diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 6e4f5fa46..cd423bf71 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -82,12 +82,16 @@ func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) } // ForEachPrimaryEndpoint calls f for each primary address. -func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { +// +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) { a.mu.RLock() defer a.mu.RUnlock() for _, ep := range a.mu.primary { - f(ep) + if !f(ep) { + return + } } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 5887aa1ed..a6237dd5f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -172,7 +172,7 @@ func (n *NIC) disable() { // // n MUST be locked. func (n *NIC) disableLocked() { - if !n.setEnabled(false) { + if !n.Enabled() { return } @@ -184,6 +184,10 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() } + + if !n.setEnabled(false) { + panic("should have only done work to disable the NIC if it was enabled") + } } // enable enables n. |