summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go34
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go34
-rw-r--r--pkg/tcpip/network/multicast_group_test.go327
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go11
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go28
5 files changed, 396 insertions, 38 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index be9c8e2f9..ce2087002 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -123,6 +123,15 @@ func (e *endpoint) Enable() *tcpip.Error {
// We have no need for the address endpoint.
ep.DecRef()
+ // Groups may have been joined while the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of IGMP when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ joinedGroups := e.mu.addressableEndpointState.JoinedGroups()
+ for _, group := range joinedGroups {
+ e.igmp.joinGroup(group)
+ }
+
// As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
// multicast group. Note, the IANA calls the all-hosts multicast group the
// all-systems multicast group.
@@ -168,6 +177,13 @@ func (e *endpoint) disableLocked() {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
+ // Leave groups from the perspective of IGMP so that routers know that
+ // we are no longer interested in the group.
+ joinedGroups := e.mu.addressableEndpointState.JoinedGroups()
+ for _, group := range joinedGroups {
+ e.igmp.leaveGroup(group)
+ }
+
// The address may have already been removed.
if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
@@ -853,6 +869,15 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
return joined, err
}
+ // Only join the group from the perspective of IGMP when the endpoint is
+ // enabled.
+ //
+ // If we are not enabled right now, we will join the group from the
+ // perspective of IGMP when the endpoint is enabled.
+ if !e.Enabled() {
+ return true, nil
+ }
+
// joinGroup only returns an error if we try to join a group twice, but we
// checked above to make sure that the group was newly joined.
if err := e.igmp.joinGroup(addr); err != nil {
@@ -874,15 +899,12 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
// Precondition: e.mu must be locked.
func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
left, err := e.mu.addressableEndpointState.LeaveGroup(addr)
- if err != nil {
+ if err != nil || !left {
return left, err
}
- if left {
- e.igmp.leaveGroup(addr)
- }
-
- return left, nil
+ e.igmp.leaveGroup(addr)
+ return true, nil
}
// IsInGroup implements stack.GroupAddressableEndpoint.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index ac67d4ac5..4d49afcbb 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -228,6 +228,15 @@ func (e *endpoint) Enable() *tcpip.Error {
return nil
}
+ // Groups may have been joined when the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of MLD when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ joinedGroups := e.mu.addressableEndpointState.JoinedGroups()
+ for _, group := range joinedGroups {
+ e.mld.joinGroup(group)
+ }
+
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
// and responds to the various NDP messages that are destined to the
@@ -338,6 +347,13 @@ func (e *endpoint) disableLocked() {
if _, err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
+
+ // Leave groups from the perspective of MLD so that routers know that
+ // we are no longer interested in the group.
+ joinedGroups := e.mu.addressableEndpointState.JoinedGroups()
+ for _, group := range joinedGroups {
+ e.mld.leaveGroup(group)
+ }
}
// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
@@ -1409,6 +1425,15 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
return joined, err
}
+ // Only join the group from the perspective of IGMP when the endpoint is
+ // enabled.
+ //
+ // If we are not enabled right now, we will join the group from the
+ // perspective of MLD when the endpoint is enabled.
+ if !e.Enabled() {
+ return true, nil
+ }
+
// joinGroup only returns an error if we try to join a group twice, but we
// checked above to make sure that the group was newly joined.
if err := e.mld.joinGroup(addr); err != nil {
@@ -1430,15 +1455,12 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
// Precondition: e.mu must be locked.
func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) {
left, err := e.mu.addressableEndpointState.LeaveGroup(addr)
- if err != nil {
+ if err != nil || !left {
return left, err
}
- if left {
- e.mld.leaveGroup(addr)
- }
-
- return left, nil
+ e.mld.leaveGroup(addr)
+ return true, nil
}
// IsInGroup implements stack.GroupAddressableEndpoint.
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index 12f9fbe82..d3517c364 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -34,8 +34,12 @@ import (
const (
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- ipv4MulticastAddr = tcpip.Address("\xe0\x00\x00\x03")
- ipv6MulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03")
+ ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04")
+ ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05")
+ ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04")
+ ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05")
igmpMembershipQuery = uint8(header.IGMPMembershipQuery)
igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
@@ -97,9 +101,9 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.
func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
t.Helper()
- // Create an endpoint of queue size 1, since no more than 1 packets are ever
+ // Create an endpoint of queue size 2, since no more than 2 packets are ever
// queued in the tests in this file.
- e := channel.New(1, 1280, linkAddr)
+ e := channel.New(2, 1280, linkAddr)
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
@@ -198,7 +202,7 @@ func TestMGPDisabled(t *testing.T) {
{
name: "IGMP",
protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr,
+ multicastAddr: ipv4MulticastAddr1,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().IGMP.PacketsSent.V2MembershipReport
},
@@ -212,7 +216,7 @@ func TestMGPDisabled(t *testing.T) {
{
name: "MLD",
protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr,
+ multicastAddr: ipv6MulticastAddr1,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
},
@@ -375,7 +379,7 @@ func TestMGPJoinGroup(t *testing.T) {
{
name: "IGMP",
protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr,
+ multicastAddr: ipv4MulticastAddr1,
maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().IGMP.PacketsSent.V2MembershipReport
@@ -384,13 +388,15 @@ func TestMGPJoinGroup(t *testing.T) {
return s.Stats().IGMP.PacketsReceived.MembershipQuery
},
validateReport: func(t *testing.T, p channel.PacketInfo) {
- validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr)
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
},
},
{
name: "MLD",
protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr,
+ multicastAddr: ipv6MulticastAddr1,
maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
@@ -399,7 +405,9 @@ func TestMGPJoinGroup(t *testing.T) {
return s.Stats().ICMP.V6PacketsReceived.MulticastListenerQuery
},
validateReport: func(t *testing.T, p channel.PacketInfo) {
- validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr)
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
},
},
}
@@ -466,7 +474,7 @@ func TestMGPLeaveGroup(t *testing.T) {
{
name: "IGMP",
protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr,
+ multicastAddr: ipv4MulticastAddr1,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().IGMP.PacketsSent.V2MembershipReport
},
@@ -474,16 +482,20 @@ func TestMGPLeaveGroup(t *testing.T) {
return s.Stats().IGMP.PacketsSent.LeaveGroup
},
validateReport: func(t *testing.T, p channel.PacketInfo) {
- validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr)
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
},
validateLeave: func(t *testing.T, p channel.PacketInfo) {
- validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr)
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1)
},
},
{
name: "MLD",
protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr,
+ multicastAddr: ipv6MulticastAddr1,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
},
@@ -491,10 +503,14 @@ func TestMGPLeaveGroup(t *testing.T) {
return s.Stats().ICMP.V6PacketsSent.MulticastListenerDone
},
validateReport: func(t *testing.T, p channel.PacketInfo) {
- validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr)
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
},
validateLeave: func(t *testing.T, p channel.PacketInfo) {
- validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr)
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
},
},
}
@@ -557,7 +573,7 @@ func TestMGPQueryMessages(t *testing.T) {
{
name: "IGMP",
protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr,
+ multicastAddr: ipv4MulticastAddr1,
maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().IGMP.PacketsSent.V2MembershipReport
@@ -569,14 +585,16 @@ func TestMGPQueryMessages(t *testing.T) {
createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress)
},
validateReport: func(t *testing.T, p channel.PacketInfo) {
- validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr)
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
},
maxRespTimeToDuration: header.DecisecondToDuration,
},
{
name: "MLD",
protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr,
+ multicastAddr: ipv6MulticastAddr1,
maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
@@ -588,7 +606,9 @@ func TestMGPQueryMessages(t *testing.T) {
createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress)
},
validateReport: func(t *testing.T, p channel.PacketInfo) {
- validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr)
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
},
maxRespTimeToDuration: func(d uint8) time.Duration {
return time.Duration(d) * time.Millisecond
@@ -702,7 +722,7 @@ func TestMGPReportMessages(t *testing.T) {
{
name: "IGMP",
protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr,
+ multicastAddr: ipv4MulticastAddr1,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().IGMP.PacketsSent.V2MembershipReport
},
@@ -710,17 +730,19 @@ func TestMGPReportMessages(t *testing.T) {
return s.Stats().IGMP.PacketsSent.LeaveGroup
},
rxReport: func(e *channel.Endpoint) {
- createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr)
+ createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
},
validateReport: func(t *testing.T, p channel.PacketInfo) {
- validateIGMPPacket(t, p, ipv4MulticastAddr, igmpv2MembershipReport, 0, ipv4MulticastAddr)
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
},
maxRespTimeToDuration: header.DecisecondToDuration,
},
{
name: "MLD",
protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr,
+ multicastAddr: ipv6MulticastAddr1,
sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
},
@@ -728,10 +750,12 @@ func TestMGPReportMessages(t *testing.T) {
return s.Stats().ICMP.V6PacketsSent.MulticastListenerDone
},
rxReport: func(e *channel.Endpoint) {
- createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr)
+ createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1)
},
validateReport: func(t *testing.T, p channel.PacketInfo) {
- validateMLDPacket(t, p, ipv6MulticastAddr, mldReport, 0, ipv6MulticastAddr)
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
},
maxRespTimeToDuration: func(d uint8) time.Duration {
return time.Duration(d) * time.Millisecond
@@ -791,3 +815,254 @@ func TestMGPReportMessages(t *testing.T) {
})
}
}
+
+func TestMGPWithNICLifecycle(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddrs []tcpip.Address
+ finalMulticastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo, tcpip.Address)
+ validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address)
+ getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2},
+ finalMulticastAddr: ipv4MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber {
+ t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber)
+ }
+ addr := header.IGMP(ipv4.Payload()).GroupAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2},
+ finalMulticastAddr: ipv6MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, addr, mldReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ if got := tcpip.TransportProtocolNumber(ipv6.NextHeader()); got != header.ICMPv6ProtocolNumber {
+ t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber)
+ }
+ icmpv6 := header.ICMPv6(ipv6.Payload())
+ if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone {
+ t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone)
+ }
+ addr := header.MLD(icmpv6.MessageBody()).MulticastAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ sentReportStat := test.sentReportStat(s)
+ var reportCounter uint64
+ for _, a := range test.multicastAddrs {
+ if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected a report message to be sent for %s", a)
+ } else {
+ test.validateReport(t, p, a)
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leave messages should be sent for the joined groups when the NIC is
+ // disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ sentLeaveStat := test.sentLeaveStat(s)
+ leaveCounter := uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i, _ := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+
+ test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Reports should be sent for the joined groups when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter += uint64(len(test.multicastAddrs))
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i, _ := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) report message to be sent", i)
+ }
+
+ test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Joining/leaving a group while disabled should not send any messages.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ leaveCounter += uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ for i, _ := range test.multicastAddrs {
+ if _, ok := e.Read(); !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+ }
+ for _, a := range test.multicastAddrs {
+ if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err)
+ }
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt)
+ }
+ }
+ if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt)
+ }
+
+ // A report should only be sent for the group we last joined after
+ // enabling the NIC since the original groups were all left.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportState.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index adeebfe37..6e855d815 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -625,6 +625,17 @@ func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
return ok
}
+// JoinedGroups returns a list of groups the endpoint is a member of.
+func (a *AddressableEndpointState) JoinedGroups() []tcpip.Address {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ groups := make([]tcpip.Address, 0, len(a.mu.groups))
+ for g := range a.mu.groups {
+ groups = append(groups, g)
+ }
+ return groups
+}
+
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 26787d0a3..0c8040c67 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -15,12 +15,40 @@
package stack_test
import (
+ "sort"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+func TestJoinedGroups(t *testing.T) {
+ const addr1 = tcpip.Address("\x01")
+ const addr2 = tcpip.Address("\x02")
+
+ var ep fakeNetworkEndpoint
+ var s stack.AddressableEndpointState
+ s.Init(&ep)
+
+ if joined, err := s.JoinGroup(addr1); err != nil {
+ t.Fatalf("JoinGroup(%s): %s", addr1, err)
+ } else if !joined {
+ t.Errorf("got JoinGroup(%s) = false, want = true", addr1)
+ }
+ if joined, err := s.JoinGroup(addr2); err != nil {
+ t.Fatalf("JoinGroup(%s): %s", addr2, err)
+ } else if !joined {
+ t.Errorf("got JoinGroup(%s) = false, want = true", addr2)
+ }
+
+ joinedGroups := s.JoinedGroups()
+ sort.Slice(joinedGroups, func(i, j int) bool { return joinedGroups[i][0] < joinedGroups[j][0] })
+ if diff := cmp.Diff([]tcpip.Address{addr1, addr2}, joinedGroups); diff != "" {
+ t.Errorf("joined groups mismatch (-want +got):\n%s", diff)
+ }
+}
+
// TestAddressableEndpointStateCleanup tests that cleaning up an addressable
// endpoint state removes permanent addresses and leaves groups.
func TestAddressableEndpointStateCleanup(t *testing.T) {