diff options
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/network/multicast_group_test.go | 327 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state_test.go | 28 |
5 files changed, 396 insertions, 38 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index be9c8e2f9..ce2087002 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -123,6 +123,15 @@ func (e *endpoint) Enable() *tcpip.Error { // We have no need for the address endpoint. ep.DecRef() + // Groups may have been joined while the endpoint was disabled, or the + // endpoint may have left groups from the perspective of IGMP when the + // endpoint was disabled. Either way, we need to let routers know to + // send us multicast traffic. + joinedGroups := e.mu.addressableEndpointState.JoinedGroups() + for _, group := range joinedGroups { + e.igmp.joinGroup(group) + } + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the // all-systems multicast group. @@ -168,6 +177,13 @@ func (e *endpoint) disableLocked() { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } + // Leave groups from the perspective of IGMP so that routers know that + // we are no longer interested in the group. + joinedGroups := e.mu.addressableEndpointState.JoinedGroups() + for _, group := range joinedGroups { + e.igmp.leaveGroup(group) + } + // The address may have already been removed. if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) @@ -853,6 +869,15 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { return joined, err } + // Only join the group from the perspective of IGMP when the endpoint is + // enabled. + // + // If we are not enabled right now, we will join the group from the + // perspective of IGMP when the endpoint is enabled. + if !e.Enabled() { + return true, nil + } + // joinGroup only returns an error if we try to join a group twice, but we // checked above to make sure that the group was newly joined. if err := e.igmp.joinGroup(addr); err != nil { @@ -874,15 +899,12 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { // Precondition: e.mu must be locked. func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { left, err := e.mu.addressableEndpointState.LeaveGroup(addr) - if err != nil { + if err != nil || !left { return left, err } - if left { - e.igmp.leaveGroup(addr) - } - - return left, nil + e.igmp.leaveGroup(addr) + return true, nil } // IsInGroup implements stack.GroupAddressableEndpoint. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index ac67d4ac5..4d49afcbb 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -228,6 +228,15 @@ func (e *endpoint) Enable() *tcpip.Error { return nil } + // Groups may have been joined when the endpoint was disabled, or the + // endpoint may have left groups from the perspective of MLD when the + // endpoint was disabled. Either way, we need to let routers know to + // send us multicast traffic. + joinedGroups := e.mu.addressableEndpointState.JoinedGroups() + for _, group := range joinedGroups { + e.mld.joinGroup(group) + } + // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives // and responds to the various NDP messages that are destined to the @@ -338,6 +347,13 @@ func (e *endpoint) disableLocked() { if _, err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } + + // Leave groups from the perspective of MLD so that routers know that + // we are no longer interested in the group. + joinedGroups := e.mu.addressableEndpointState.JoinedGroups() + for _, group := range joinedGroups { + e.mld.leaveGroup(group) + } } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -1409,6 +1425,15 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { return joined, err } + // Only join the group from the perspective of IGMP when the endpoint is + // enabled. + // + // If we are not enabled right now, we will join the group from the + // perspective of MLD when the endpoint is enabled. + if !e.Enabled() { + return true, nil + } + // joinGroup only returns an error if we try to join a group twice, but we // checked above to make sure that the group was newly joined. if err := e.mld.joinGroup(addr); err != nil { @@ -1430,15 +1455,12 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { // Precondition: e.mu must be locked. func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { left, err := e.mu.addressableEndpointState.LeaveGroup(addr) - if err != nil { + if err != nil || !left { return left, err } - if left { - e.mld.leaveGroup(addr) - } - - return left, nil + e.mld.leaveGroup(addr) + return true, nil } // IsInGroup implements stack.GroupAddressableEndpoint. diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 12f9fbe82..d3517c364 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -34,8 +34,12 @@ import ( const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - ipv4MulticastAddr = tcpip.Address("\xe0\x00\x00\x03") - ipv6MulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") + ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") + ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") + ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") + ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") igmpMembershipQuery = uint8(header.IGMPMembershipQuery) igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) @@ -97,9 +101,9 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { t.Helper() - // Create an endpoint of queue size 1, since no more than 1 packets are ever + // Create an endpoint of queue size 2, since no more than 2 packets are ever // queued in the tests in this file. - e := channel.New(1, 1280, linkAddr) + e := channel.New(2, 1280, linkAddr) clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -198,7 +202,7 @@ func TestMGPDisabled(t *testing.T) { { name: "IGMP", protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr, + multicastAddr: ipv4MulticastAddr1, sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { return s.Stats().IGMP.PacketsSent.V2MembershipReport }, @@ -212,7 +216,7 @@ func TestMGPDisabled(t *testing.T) { { name: "MLD", protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr, + multicastAddr: ipv6MulticastAddr1, sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport }, @@ -375,7 +379,7 @@ func TestMGPJoinGroup(t *testing.T) { { name: "IGMP", protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr, + multicastAddr: ipv4MulticastAddr1, maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { return s.Stats().IGMP.PacketsSent.V2MembershipReport @@ -384,13 +388,15 @@ func TestMGPJoinGroup(t *testing.T) { return s.Stats().IGMP.PacketsReceived.MembershipQuery }, validateReport: func(t *testing.T, p channel.PacketInfo) { - validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr) + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) }, }, { name: "MLD", protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr, + multicastAddr: ipv6MulticastAddr1, maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport @@ -399,7 +405,9 @@ func TestMGPJoinGroup(t *testing.T) { return s.Stats().ICMP.V6PacketsReceived.MulticastListenerQuery }, validateReport: func(t *testing.T, p channel.PacketInfo) { - validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr) + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) }, }, } @@ -466,7 +474,7 @@ func TestMGPLeaveGroup(t *testing.T) { { name: "IGMP", protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr, + multicastAddr: ipv4MulticastAddr1, sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { return s.Stats().IGMP.PacketsSent.V2MembershipReport }, @@ -474,16 +482,20 @@ func TestMGPLeaveGroup(t *testing.T) { return s.Stats().IGMP.PacketsSent.LeaveGroup }, validateReport: func(t *testing.T, p channel.PacketInfo) { - validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr) + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) }, validateLeave: func(t *testing.T, p channel.PacketInfo) { - validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr) + t.Helper() + + validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1) }, }, { name: "MLD", protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr, + multicastAddr: ipv6MulticastAddr1, sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport }, @@ -491,10 +503,14 @@ func TestMGPLeaveGroup(t *testing.T) { return s.Stats().ICMP.V6PacketsSent.MulticastListenerDone }, validateReport: func(t *testing.T, p channel.PacketInfo) { - validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr) + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) }, validateLeave: func(t *testing.T, p channel.PacketInfo) { - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr) + t.Helper() + + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) }, }, } @@ -557,7 +573,7 @@ func TestMGPQueryMessages(t *testing.T) { { name: "IGMP", protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr, + multicastAddr: ipv4MulticastAddr1, maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { return s.Stats().IGMP.PacketsSent.V2MembershipReport @@ -569,14 +585,16 @@ func TestMGPQueryMessages(t *testing.T) { createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress) }, validateReport: func(t *testing.T, p channel.PacketInfo) { - validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr) + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) }, maxRespTimeToDuration: header.DecisecondToDuration, }, { name: "MLD", protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr, + multicastAddr: ipv6MulticastAddr1, maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport @@ -588,7 +606,9 @@ func TestMGPQueryMessages(t *testing.T) { createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress) }, validateReport: func(t *testing.T, p channel.PacketInfo) { - validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr) + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) }, maxRespTimeToDuration: func(d uint8) time.Duration { return time.Duration(d) * time.Millisecond @@ -702,7 +722,7 @@ func TestMGPReportMessages(t *testing.T) { { name: "IGMP", protoNum: ipv4.ProtocolNumber, - multicastAddr: ipv4MulticastAddr, + multicastAddr: ipv4MulticastAddr1, sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { return s.Stats().IGMP.PacketsSent.V2MembershipReport }, @@ -710,17 +730,19 @@ func TestMGPReportMessages(t *testing.T) { return s.Stats().IGMP.PacketsSent.LeaveGroup }, rxReport: func(e *channel.Endpoint) { - createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr) + createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1) }, validateReport: func(t *testing.T, p channel.PacketInfo) { - validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr) + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) }, maxRespTimeToDuration: header.DecisecondToDuration, }, { name: "MLD", protoNum: ipv6.ProtocolNumber, - multicastAddr: ipv6MulticastAddr, + multicastAddr: ipv6MulticastAddr1, sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport }, @@ -728,10 +750,12 @@ func TestMGPReportMessages(t *testing.T) { return s.Stats().ICMP.V6PacketsSent.MulticastListenerDone }, rxReport: func(e *channel.Endpoint) { - createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr) + createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1) }, validateReport: func(t *testing.T, p channel.PacketInfo) { - validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr) + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) }, maxRespTimeToDuration: func(d uint8) time.Duration { return time.Duration(d) * time.Millisecond @@ -791,3 +815,254 @@ func TestMGPReportMessages(t *testing.T) { }) } } + +func TestMGPWithNICLifecycle(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddrs []tcpip.Address + finalMulticastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) + validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) + getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2}, + finalMulticastAddr: ipv4MulticastAddr3, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr) + }, + getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { + t.Helper() + + ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber { + t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber) + } + addr := header.IGMP(ipv4.Payload()).GroupAddress() + s, ok := seen[addr] + if !ok { + t.Fatalf("unexpectedly got a packet for group %s", addr) + } + if s { + t.Fatalf("already saw packet for group %s", addr) + } + seen[addr] = true + return addr + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2}, + finalMulticastAddr: ipv6MulticastAddr3, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6PacketsSent.MulticastListenerDone + }, + validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateMLDPacket(t, p, addr, mldReport, 0, addr) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) + }, + getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { + t.Helper() + + ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) + if got := tcpip.TransportProtocolNumber(ipv6.NextHeader()); got != header.ICMPv6ProtocolNumber { + t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) + } + icmpv6 := header.ICMPv6(ipv6.Payload()) + if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { + t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) + } + addr := header.MLD(icmpv6.MessageBody()).MulticastAddress() + s, ok := seen[addr] + if !ok { + t.Fatalf("unexpectedly got a packet for group %s", addr) + } + if s { + t.Fatalf("already saw packet for group %s", addr) + } + seen[addr] = true + return addr + + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, true) + + sentReportStat := test.sentReportStat(s) + var reportCounter uint64 + for _, a := range test.multicastAddrs { + if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) + } + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatalf("expected a report message to be sent for %s", a) + } else { + test.validateReport(t, p, a) + } + } + if t.Failed() { + t.FailNow() + } + + // Leave messages should be sent for the joined groups when the NIC is + // disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + sentLeaveStat := test.sentLeaveStat(s) + leaveCounter := uint64(len(test.multicastAddrs)) + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + { + seen := make(map[tcpip.Address]bool) + for _, a := range test.multicastAddrs { + seen[a] = false + } + + for i, _ := range test.multicastAddrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected (%d-th) leave message to be sent", i) + } + + test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p)) + } + } + if t.Failed() { + t.FailNow() + } + + // Reports should be sent for the joined groups when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + reportCounter += uint64(len(test.multicastAddrs)) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + { + seen := make(map[tcpip.Address]bool) + for _, a := range test.multicastAddrs { + seen[a] = false + } + + for i, _ := range test.multicastAddrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected (%d-th) report message to be sent", i) + } + + test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p)) + } + } + if t.Failed() { + t.FailNow() + } + + // Joining/leaving a group while disabled should not send any messages. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + leaveCounter += uint64(len(test.multicastAddrs)) + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + for i, _ := range test.multicastAddrs { + if _, ok := e.Read(); !ok { + t.Fatalf("expected (%d-th) leave message to be sent", i) + } + } + for _, a := range test.multicastAddrs { + if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err) + } + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); ok { + t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt) + } + } + if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err) + } + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); ok { + t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt) + } + + // A report should only be sent for the group we last joined after + // enabling the NIC since the original groups were all left. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p, test.finalMulticastAddr) + } + + clock.Advance(test.maxUnsolicitedResponseDelay) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportState.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p, test.finalMulticastAddr) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index adeebfe37..6e855d815 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -625,6 +625,17 @@ func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { return ok } +// JoinedGroups returns a list of groups the endpoint is a member of. +func (a *AddressableEndpointState) JoinedGroups() []tcpip.Address { + a.mu.RLock() + defer a.mu.RUnlock() + groups := make([]tcpip.Address, 0, len(a.mu.groups)) + for g := range a.mu.groups { + groups = append(groups, g) + } + return groups +} + // Cleanup forcefully leaves all groups and removes all permanent addresses. func (a *AddressableEndpointState) Cleanup() { a.mu.Lock() diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 26787d0a3..0c8040c67 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -15,12 +15,40 @@ package stack_test import ( + "sort" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +func TestJoinedGroups(t *testing.T) { + const addr1 = tcpip.Address("\x01") + const addr2 = tcpip.Address("\x02") + + var ep fakeNetworkEndpoint + var s stack.AddressableEndpointState + s.Init(&ep) + + if joined, err := s.JoinGroup(addr1); err != nil { + t.Fatalf("JoinGroup(%s): %s", addr1, err) + } else if !joined { + t.Errorf("got JoinGroup(%s) = false, want = true", addr1) + } + if joined, err := s.JoinGroup(addr2); err != nil { + t.Fatalf("JoinGroup(%s): %s", addr2, err) + } else if !joined { + t.Errorf("got JoinGroup(%s) = false, want = true", addr2) + } + + joinedGroups := s.JoinedGroups() + sort.Slice(joinedGroups, func(i, j int) bool { return joinedGroups[i][0] < joinedGroups[j][0] }) + if diff := cmp.Diff([]tcpip.Address{addr1, addr2}, joinedGroups); diff != "" { + t.Errorf("joined groups mismatch (-want +got):\n%s", diff) + } +} + // TestAddressableEndpointStateCleanup tests that cleaning up an addressable // endpoint state removes permanent addresses and leaves groups. func TestAddressableEndpointStateCleanup(t *testing.T) { |