diff options
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol.go | 398 | ||||
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol_test.go | 456 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 82 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 69 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld.go | 75 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state.go | 61 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state_test.go | 50 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 8 |
10 files changed, 791 insertions, 493 deletions
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index 3113f4bbe..f14e2a88a 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -31,12 +31,15 @@ type hostState int // (RFC 2710 section 5). Even though the states are generic across both IGMPv2 // and MLDv1, IGMPv2 terminology will be used. const ( - // "'Non-Member' state, when the host does not belong to the group on - // the interface. This is the initial state for all memberships on + // nonMember is the "'Non-Member' state, when the host does not belong to the + // group on the interface. This is the initial state for all memberships on // all network interfaces; it requires no storage in the host." // // 'Non-Listener' is the MLDv1 term used to describe this state. - _ hostState = iota + // + // This state is used to keep track of groups that have been joined locally, + // but without advertising the membership to the network. + nonMember hostState = iota // delayingMember is the "'Delaying Member' state, when the host belongs to // the group on the interface and has a report delay timer running for that @@ -56,7 +59,10 @@ const ( // multicastGroupState holds the Generic Multicast Protocol state for a // multicast group. type multicastGroupState struct { - // state contains the host's state for the group. + // joins is the number of times the group has been joined. + joins uint64 + + // state holds the host's state for the group. state hostState // lastToSendReport is true if we sent the last report for the group. It is @@ -75,6 +81,45 @@ type multicastGroupState struct { delayedReportJob *tcpip.Job } +// GenericMulticastProtocolOptions holds options for the generic multicast +// protocol. +type GenericMulticastProtocolOptions struct { + // Enabled indicates whether the generic multicast protocol will be + // performed. + // + // When enabled, the protocol may transmit report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // packets. + // + // When disabled, the protocol will still keep track of locally joined groups, + // it just won't transmit and handle packets, or update groups' state. + Enabled bool + + // Rand is the source of random numbers. + Rand *rand.Rand + + // Clock is the clock used to create timers. + Clock tcpip.Clock + + // Protocol is the implementation of the variant of multicast group protocol + // in use. + Protocol MulticastGroupProtocol + + // MaxUnsolicitedReportDelay is the maximum amount of time to wait between + // transmitting unsolicited reports. + // + // Unsolicited reports are transmitted when a group is newly joined. + MaxUnsolicitedReportDelay time.Duration + + // AllNodesAddress is a multicast address that all nodes on a network should + // be a member of. + // + // This address will not have the generic multicast protocol performed on it; + // it will be left in the non member/listener state, and packets will never + // be sent for it. + AllNodesAddress tcpip.Address +} + // MulticastGroupProtocol is a multicast group protocol whose core state machine // can be represented by GenericMulticastProtocolState. type MulticastGroupProtocol interface { @@ -96,13 +141,10 @@ type MulticastGroupProtocol interface { // GenericMulticastProtocolState.Init MUST be called before calling any of // the methods on GenericMulticastProtocolState. type GenericMulticastProtocolState struct { - rand *rand.Rand - clock tcpip.Clock - protocol MulticastGroupProtocol - maxUnsolicitedReportDelay time.Duration + opts GenericMulticastProtocolOptions mu struct { - sync.Mutex + sync.RWMutex // memberships holds group addresses and their associated state. memberships map[tcpip.Address]multicastGroupState @@ -113,130 +155,123 @@ type GenericMulticastProtocolState struct { // // maxUnsolicitedReportDelay is the maximum time between sending unsolicited // reports after joining a group. -func (g *GenericMulticastProtocolState) Init(rand *rand.Rand, clock tcpip.Clock, protocol MulticastGroupProtocol, maxUnsolicitedReportDelay time.Duration) { +func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) { g.mu.Lock() defer g.mu.Unlock() - g.rand = rand - g.clock = clock - g.protocol = protocol - g.maxUnsolicitedReportDelay = maxUnsolicitedReportDelay + g.opts = opts g.mu.memberships = make(map[tcpip.Address]multicastGroupState) } +// MakeAllNonMember transitions all groups to the non-member state. +// +// The groups will still be considered joined locally. +func (g *GenericMulticastProtocolState) MakeAllNonMember() { + if !g.opts.Enabled { + return + } + + g.mu.Lock() + defer g.mu.Unlock() + + for groupAddress, info := range g.mu.memberships { + g.transitionToNonMemberLocked(groupAddress, &info) + g.mu.memberships[groupAddress] = info + } +} + +// InitializeGroups initializes each group, as if they were newly joined but +// without affecting the groups' join count. +// +// Must only be called after calling MakeAllNonMember as a group should not be +// initialized while it is not in the non-member state. +func (g *GenericMulticastProtocolState) InitializeGroups() { + if !g.opts.Enabled { + return + } + + g.mu.Lock() + defer g.mu.Unlock() + + for groupAddress, info := range g.mu.memberships { + g.initializeNewMemberLocked(groupAddress, &info) + g.mu.memberships[groupAddress] = info + } +} + // JoinGroup handles joining a new group. // -// Returns false if the group has already been joined. -func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address) bool { +// If dontInitialize is true, the group will be not be initialized and will be +// left in the non-member state - no packets will be sent for it until it is +// initialized via InitializeGroups. +func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) { g.mu.Lock() defer g.mu.Unlock() - if _, ok := g.mu.memberships[groupAddress]; ok { + if info, ok := g.mu.memberships[groupAddress]; ok { // The group has already been joined. - return false + info.joins++ + g.mu.memberships[groupAddress] = info + return } info := multicastGroupState{ - // There isn't a job scheduled currently, so it's just idle. - state: idleMember, - // Joining a group immediately sends a report. - lastToSendReport: true, - delayedReportJob: tcpip.NewJob(g.clock, &g.mu, func() { + // Since we just joined the group, its count is 1. + joins: 1, + // The state will be updated below, if required. + state: nonMember, + lastToSendReport: false, + delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() { info, ok := g.mu.memberships[groupAddress] if !ok { panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) } - info.lastToSendReport = g.protocol.SendReport(groupAddress) == nil + info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil info.state = idleMember g.mu.memberships[groupAddress] = info }), } - // As per RFC 2236 section 3 page 5 (for IGMPv2), - // - // When a host joins a multicast group, it should immediately transmit an - // unsolicited Version 2 Membership Report for that group" ... "it is - // recommended that it be repeated". - // - // As per RFC 2710 section 4 page 6 (for MLDv1), - // - // When a node starts listening to a multicast address on an interface, - // it should immediately transmit an unsolicited Report for that address - // on that interface, in case it is the first listener on the link. To - // cover the possibility of the initial Report being lost or damaged, it - // is recommended that it be repeated once or twice after short delays - // [Unsolicited Report Interval]. - // - // TODO(gvisor.dev/issue/4901): Support a configurable number of initial - // unsolicited reports. - info.lastToSendReport = g.protocol.SendReport(groupAddress) == nil - g.setDelayTimerForAddressRLocked(groupAddress, &info, g.maxUnsolicitedReportDelay) + if !dontInitialize && g.opts.Enabled { + g.initializeNewMemberLocked(groupAddress, &info) + } + g.mu.memberships[groupAddress] = info - return true +} + +// IsLocallyJoined returns true if the group is locally joined. +func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool { + g.mu.RLock() + defer g.mu.RUnlock() + _, ok := g.mu.memberships[groupAddress] + return ok } // LeaveGroup handles leaving the group. -func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) { +// +// Returns false if the group is not currently joined. +func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool { g.mu.Lock() defer g.mu.Unlock() info, ok := g.mu.memberships[groupAddress] if !ok { - return + return false } - info.delayedReportJob.Cancel() - delete(g.mu.memberships, groupAddress) - if info.lastToSendReport { - // Okay to ignore the error here as if packet write failed, the multicast - // routers will eventually drop our membership anyways. If the interface is - // being disabled or removed, the generic multicast protocol's should be - // cleared eventually. - // - // As per RFC 2236 section 3 page 5 (for IGMPv2), - // - // When a router receives a Report, it adds the group being reported to - // the list of multicast group memberships on the network on which it - // received the Report and sets the timer for the membership to the - // [Group Membership Interval]. Repeated Reports refresh the timer. If - // no Reports are received for a particular group before this timer has - // expired, the router assumes that the group has no local members and - // that it need not forward remotely-originated multicasts for that - // group onto the attached network. - // - // As per RFC 2710 section 4 page 5 (for MLDv1), - // - // When a router receives a Report from a link, if the reported address - // is not already present in the router's list of multicast address - // having listeners on that link, the reported address is added to the - // list, its timer is set to [Multicast Listener Interval], and its - // appearance is made known to the router's multicast routing component. - // If a Report is received for a multicast address that is already - // present in the router's list, the timer for that address is reset to - // [Multicast Listener Interval]. If an address's timer expires, it is - // assumed that there are no longer any listeners for that address - // present on the link, so it is deleted from the list and its - // disappearance is made known to the multicast routing component. - // - // The requirement to send a leave message is also optional (it MAY be - // skipped): - // - // As per RFC 2236 section 6 page 8 (for IGMPv2), - // - // "send leave" for the group on the interface. If the interface - // state says the Querier is running IGMPv1, this action SHOULD be - // skipped. If the flag saying we were the last host to report is - // cleared, this action MAY be skipped. The Leave Message is sent to - // the ALL-ROUTERS group (224.0.0.2). - // - // As per RFC 2710 section 5 page 8 (for MLDv1), - // - // "send done" for the address on the interface. If the flag saying - // we were the last node to report is cleared, this action MAY be - // skipped. The Done message is sent to the link-scope all-routers - // address (FF02::2). - _ = g.protocol.SendLeave(groupAddress) + if info.joins == 0 { + panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress)) + } + info.joins-- + if info.joins != 0 { + // If we still have outstanding joins, then do nothing further. + g.mu.memberships[groupAddress] = info + return true } + + g.transitionToNonMemberLocked(groupAddress, &info) + delete(g.mu.memberships, groupAddress) + return true } // HandleQuery handles a query message with the specified maximum response time. @@ -247,6 +282,10 @@ func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) { // Report(s) will be scheduled to be sent after a random duration between 0 and // the maximum response time. func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) { + if !g.opts.Enabled { + return + } + g.mu.Lock() defer g.mu.Unlock() @@ -278,6 +317,10 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, // If the report is for a joined group, any active delayed report will be // cancelled and the host state for the group transitions to idle. func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) { + if !g.opts.Enabled { + return + } + g.mu.Lock() defer g.mu.Unlock() @@ -293,7 +336,7 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) // multicast address while it has a timer running for that same address // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. - if info, ok := g.mu.memberships[groupAddress]; ok { + if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember { info.delayedReportJob.Cancel() info.lastToSendReport = false info.state = idleMember @@ -301,10 +344,167 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) } } +// initializeNewMemberLocked initializes a new group membership. +// +// Precondition: g.mu must be locked. +func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != nonMember { + panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state)) + } + + info.state = idleMember + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a host joins a multicast group, it should immediately transmit an + // unsolicited Version 2 Membership Report for that group" ... "it is + // recommended that it be repeated". + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // When a node starts listening to a multicast address on an interface, + // it should immediately transmit an unsolicited Report for that address + // on that interface, in case it is the first listener on the link. To + // cover the possibility of the initial Report being lost or damaged, it + // is recommended that it be repeated once or twice after short delays + // [Unsolicited Report Interval]. + // + // TODO(gvisor.dev/issue/4901): Support a configurable number of initial + // unsolicited reports. + info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil + g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) +} + +// maybeSendLeave attempts to send a leave message. +func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) { + if !g.opts.Enabled || !lastToSendReport { + return + } + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // Okay to ignore the error here as if packet write failed, the multicast + // routers will eventually drop our membership anyways. If the interface is + // being disabled or removed, the generic multicast protocol's should be + // cleared eventually. + // + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a router receives a Report, it adds the group being reported to + // the list of multicast group memberships on the network on which it + // received the Report and sets the timer for the membership to the + // [Group Membership Interval]. Repeated Reports refresh the timer. If + // no Reports are received for a particular group before this timer has + // expired, the router assumes that the group has no local members and + // that it need not forward remotely-originated multicasts for that + // group onto the attached network. + // + // As per RFC 2710 section 4 page 5 (for MLDv1), + // + // When a router receives a Report from a link, if the reported address + // is not already present in the router's list of multicast address + // having listeners on that link, the reported address is added to the + // list, its timer is set to [Multicast Listener Interval], and its + // appearance is made known to the router's multicast routing component. + // If a Report is received for a multicast address that is already + // present in the router's list, the timer for that address is reset to + // [Multicast Listener Interval]. If an address's timer expires, it is + // assumed that there are no longer any listeners for that address + // present on the link, so it is deleted from the list and its + // disappearance is made known to the multicast routing component. + // + // The requirement to send a leave message is also optional (it MAY be + // skipped): + // + // As per RFC 2236 section 6 page 8 (for IGMPv2), + // + // "send leave" for the group on the interface. If the interface + // state says the Querier is running IGMPv1, this action SHOULD be + // skipped. If the flag saying we were the last host to report is + // cleared, this action MAY be skipped. The Leave Message is sent to + // the ALL-ROUTERS group (224.0.0.2). + // + // As per RFC 2710 section 5 page 8 (for MLDv1), + // + // "send done" for the address on the interface. If the flag saying + // we were the last node to report is cleared, this action MAY be + // skipped. The Done message is sent to the link-scope all-routers + // address (FF02::2). + _ = g.opts.Protocol.SendLeave(groupAddress) +} + +// transitionToNonMemberLocked transitions the given multicast group the the +// non-member/listener state. +// +// Precondition: e.mu must be locked. +func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state == nonMember { + return + } + + info.delayedReportJob.Cancel() + g.maybeSendLeave(groupAddress, info.lastToSendReport) + info.lastToSendReport = false + info.state = nonMember +} + // setDelayTimerForAddressRLocked sets timer to send a delay report. // // Precondition: g.mu MUST be read locked. func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { + if info.state == nonMember { + return + } + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + // As per RFC 2236 section 3 page 3 (for IGMPv2), // // If a timer for the group is already unning, it is reset to the random @@ -342,5 +542,5 @@ func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime if maxRespTime == 0 { return 0 } - return time.Duration(g.rand.Int63n(int64(maxRespTime))) + return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime))) } diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index eb48c0d51..670be30d4 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -29,18 +29,21 @@ const ( addr1 = tcpip.Address("\x01") addr2 = tcpip.Address("\x02") addr3 = tcpip.Address("\x03") + addr4 = tcpip.Address("\x04") + + maxUnsolicitedReportDelay = time.Second ) var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) type mockMulticastGroupProtocol struct { sendReportGroupAddrCount map[tcpip.Address]int - sendLeaveGroupAddr tcpip.Address + sendLeaveGroupAddrCount map[tcpip.Address]int } func (m *mockMulticastGroupProtocol) init() { m.sendReportGroupAddrCount = make(map[tcpip.Address]int) - m.sendLeaveGroupAddr = "" + m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) } func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error { @@ -49,87 +52,146 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcp } func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { - m.sendLeaveGroupAddr = groupAddress + m.sendLeaveGroupAddrCount[groupAddress]++ return nil } -func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddr tcpip.Address) string { +func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { sendReportGroupAddressesMap := make(map[tcpip.Address]int) for _, a := range sendReportGroupAddresses { sendReportGroupAddressesMap[a] = 1 } + sendLeaveGroupAddressesMap := make(map[tcpip.Address]int) + for _, a := range sendLeaveGroupAddresses { + sendLeaveGroupAddressesMap[a] = 1 + } + diff := cmp.Diff(mockMulticastGroupProtocol{ sendReportGroupAddrCount: sendReportGroupAddressesMap, - sendLeaveGroupAddr: sendLeaveGroupAddr, + sendLeaveGroupAddrCount: sendLeaveGroupAddressesMap, }, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{})) mgp.init() return diff } func TestJoinGroup(t *testing.T) { - const maxUnsolicitedReportDelay = time.Second - - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() - clock := faketime.NewManualClock() - g.Init(rand.New(rand.NewSource(0)), clock, &mgp, maxUnsolicitedReportDelay) - - // Joining a group should send a report immediately and another after - // a random interval between 0 and the maximum unsolicited report delay. - if !g.JoinGroup(addr1) { - t.Errorf("got g.JoinGroup(%s) = false, want = true", addr1) - } - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + tests := []struct { + name string + addr tcpip.Address + shouldSendReports bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendReports: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendReports: false, + }, } - clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(0)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + // Joining a group should send a report immediately and another after + // a random interval between 0 and the maximum unsolicited report delay. + g.JoinGroup(test.addr, false /* dontInitialize */) + if test.shouldSendReports { + if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + clock.Advance(maxUnsolicitedReportDelay) + if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) } } func TestLeaveGroup(t *testing.T) { - const maxUnsolicitedReportDelay = time.Second - - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() - clock := faketime.NewManualClock() - g.Init(rand.New(rand.NewSource(1)), clock, &mgp, maxUnsolicitedReportDelay) - - if !g.JoinGroup(addr1) { - t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1) - } - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + tests := []struct { + name string + addr tcpip.Address + shouldSendMessages bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendMessages: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendMessages: false, + }, } - // Leaving a group should send a leave report immediately and cancel any - // delayed reports. - g.LeaveGroup(addr1) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, addr1 /* sendLeaveGroupAddr */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(1)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + g.JoinGroup(test.addr, false /* dontInitialize */) + if test.shouldSendMessages { + if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } - // Should have no more messages to send. - clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + // Leaving a group should send a leave report immediately and cancel any + // delayed reports. + if !g.LeaveGroup(test.addr) { + t.Fatalf("got g.LeaveGroup(%s) = false, want = true", test.addr) + } + if test.shouldSendMessages { + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) } } func TestHandleReport(t *testing.T) { - const maxUnsolicitedReportDelay = time.Second - tests := []struct { name string reportAddr tcpip.Address @@ -151,10 +213,15 @@ func TestHandleReport(t *testing.T) { expectReportsFor: []tcpip.Address{addr2}, }, { - name: "Specified other", + name: "Specified all-nodes", reportAddr: addr3, expectReportsFor: []tcpip.Address{addr1, addr2}, }, + { + name: "Specified other", + reportAddr: addr4, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, } for _, test := range tests { @@ -163,18 +230,25 @@ func TestHandleReport(t *testing.T) { var mgp mockMulticastGroupProtocol mgp.init() clock := faketime.NewManualClock() - g.Init(rand.New(rand.NewSource(2)), clock, &mgp, maxUnsolicitedReportDelay) - - if !g.JoinGroup(addr1) { - t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1) - } - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(2)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + g.JoinGroup(addr1, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - if !g.JoinGroup(addr2) { - t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr2) + g.JoinGroup(addr2, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { + g.JoinGroup(addr3, false /* dontInitialize */) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } @@ -183,14 +257,14 @@ func TestHandleReport(t *testing.T) { g.HandleReport(test.reportAddr) if len(test.expectReportsFor) != 0 { clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { + if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -198,8 +272,6 @@ func TestHandleReport(t *testing.T) { } func TestHandleQuery(t *testing.T) { - const maxUnsolicitedReportDelay = time.Second - tests := []struct { name string queryAddr tcpip.Address @@ -225,11 +297,17 @@ func TestHandleQuery(t *testing.T) { expectReportsFor: []tcpip.Address{addr1}, }, { - name: "Specified other", + name: "Specified all-nodes", queryAddr: addr3, maxDelay: 3, expectReportsFor: nil, }, + { + name: "Specified other", + queryAddr: addr4, + maxDelay: 4, + expectReportsFor: nil, + }, } for _, test := range tests { @@ -238,22 +316,29 @@ func TestHandleQuery(t *testing.T) { var mgp mockMulticastGroupProtocol mgp.init() clock := faketime.NewManualClock() - g.Init(rand.New(rand.NewSource(3)), clock, &mgp, maxUnsolicitedReportDelay) - - if !g.JoinGroup(addr1) { - t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1) - } - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + g.JoinGroup(addr1, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - if !g.JoinGroup(addr2) { - t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr2) + g.JoinGroup(addr2, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { + g.JoinGroup(addr3, false /* dontInitialize */) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { + if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } @@ -262,33 +347,230 @@ func TestHandleQuery(t *testing.T) { g.HandleQuery(test.queryAddr, test.maxDelay) if len(test.expectReportsFor) != 0 { clock.Advance(test.maxDelay) - if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { + if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, "" /* sendLeaveGroupAddr */); diff != "" { + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) } } -func TestDoubleJoinGroup(t *testing.T) { +func TestJoinCount(t *testing.T) { var g ip.GenericMulticastProtocolState var mgp mockMulticastGroupProtocol mgp.init() clock := faketime.NewManualClock() - g.Init(rand.New(rand.NewSource(4)), clock, &mgp, time.Second) + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: time.Second, + }) + + // Set the join count to 2 for a group. + g.JoinGroup(addr1, false /* dontInitialize */) + if !g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + } + // Only the first join should trigger a report to be sent. + if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + g.JoinGroup(addr1, false /* dontInitialize */) + if !g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Group should still be considered joined after leaving once. + if !g.LeaveGroup(addr1) { + t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) + } + if !g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + } + // A leave report should only be sent once the join count reaches 0. + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Leaving once more should actually remove us from the group. + if !g.LeaveGroup(addr1) { + t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) + } + if g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Group should no longer be joined so we should not have anything to + // leave. + if g.LeaveGroup(addr1) { + t.Fatalf("got g.LeaveGroup(%s) = true, want = false", addr1) + } + if g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +func TestMakeAllNonMemberAndInitialize(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + g.JoinGroup(addr1, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + g.JoinGroup(addr2, false /* dontInitialize */) + if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + g.JoinGroup(addr3, false /* dontInitialize */) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should send the leave reports for each but still consider them locally + // joined. + g.MakeAllNonMember() + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + for _, group := range []tcpip.Address{addr1, addr2, addr3} { + if !g.IsLocallyJoined(group) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", group) + } + } - if !g.JoinGroup(addr1) { - t.Fatalf("got g.JoinGroup(%s) = false, want = true", addr1) + // Should send the initial set of unsolcited reports. + g.InitializeGroups() + if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + clock.Advance(maxUnsolicitedReportDelay) + if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +// TestGroupStateNonMember tests that groups do not send packets when in the +// non-member state, but are still considered locally joined. +func TestGroupStateNonMember(t *testing.T) { + tests := []struct { + name string + enabled bool + dontInitialize bool + }{ + { + name: "Disabled", + enabled: false, + dontInitialize: false, + }, + { + name: "Keep non-member", + enabled: true, + dontInitialize: true, + }, + { + name: "disabled and Keep non-member", + enabled: false, + dontInitialize: true, + }, } - // Joining the same group twice should fail. - if g.JoinGroup(addr1) { - t.Errorf("got g.JoinGroup(%s) = true, want = false", addr1) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(ip.GenericMulticastProtocolOptions{ + Enabled: test.enabled, + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + + g.JoinGroup(addr1, test.dontInitialize) + if !g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + g.JoinGroup(addr2, test.dontInitialize) + if !g.IsLocallyJoined(addr2) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr2) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + g.HandleQuery(addr1, time.Nanosecond) + clock.Advance(time.Nanosecond) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + if !g.LeaveGroup(addr2) { + t.Errorf("got g.LeaveGroup(%s) = false, want = true", addr2) + } + if !g.IsLocallyJoined(addr1) { + t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + } + if g.IsLocallyJoined(addr2) { + t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr2) + } + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + clock.Advance(time.Hour) + if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 37f1822ca..18ccd28c3 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -124,7 +124,14 @@ func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { defer igmp.mu.Unlock() igmp.ep = ep igmp.opts = opts - igmp.mu.genericMulticastProtocol.Init(ep.protocol.stack.Rand(), ep.protocol.stack.Clock(), igmp, UnsolicitedReportIntervalMax) + igmp.mu.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ + Enabled: opts.Enabled, + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: igmp, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv4AllSystems, + }) igmp.igmpV1Present = igmpV1PresentDefault igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { igmp.setV1Present(false) @@ -201,17 +208,13 @@ func (igmp *igmpState) setV1Present(v bool) { } func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) { - if !igmp.opts.Enabled { - return - } - igmp.mu.Lock() defer igmp.mu.Unlock() // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero // then change the state to note that an IGMPv1 router is present and // schedule the query received Job. - if maxRespTime == 0 { + if maxRespTime == 0 && igmp.opts.Enabled { igmp.mu.igmpV1Job.Cancel() igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout) igmp.setV1Present(true) @@ -222,10 +225,6 @@ func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxResp } func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { - if !igmp.opts.Enabled { - return - } - igmp.mu.Lock() defer igmp.mu.Unlock() igmp.mu.genericMulticastProtocol.HandleReport(groupAddress) @@ -279,49 +278,46 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // // If the group already exists in the membership map, returns // tcpip.ErrDuplicateAddress. -func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) *tcpip.Error { - if !igmp.opts.Enabled { - return nil - } - - // As per RFC 2236 section 6 page 10, - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // This is equivalent to not performing IGMP for the all-systems multicast - // address. Simply not performing IGMP when the group is added will prevent - // any work from being done on the all-systems multicast group when leaving - // the group or when query or report messages are received for it since the - // MGP state will not know about it. - if groupAddress == header.IPv4AllSystems { - return nil - } - +func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { igmp.mu.Lock() defer igmp.mu.Unlock() + igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress, !igmp.ep.Enabled() /* dontInitialize */) +} - // JoinGroup returns false if we have already joined the group. - if !igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress) { - return tcpip.ErrDuplicateAddress - } - return nil +// isInGroup returns true if the specified group has been joined locally. +func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { + igmp.mu.Lock() + defer igmp.mu.Unlock() + return igmp.mu.genericMulticastProtocol.IsLocallyJoined(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Leave Group message // if required. -// -// If the group does not exist in the membership map, this function will -// silently return. -func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) { - if !igmp.opts.Enabled { - return +func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + igmp.mu.Lock() + defer igmp.mu.Unlock() + + // LeaveGroup returns false only if the group was not joined. + if igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) { + return nil } + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of IGMP, but remains +// joined locally. +func (igmp *igmpState) softLeaveAll() { + igmp.mu.Lock() + defer igmp.mu.Unlock() + igmp.mu.genericMulticastProtocol.MakeAllNonMember() +} + +// initializeAll attemps to initialize the IGMP state for each group that has +// been joined locally. +func (igmp *igmpState) initializeAll() { igmp.mu.Lock() defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) + igmp.mu.genericMulticastProtocol.InitializeGroups() } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index ce2087002..354ac1e1d 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -127,16 +127,18 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of IGMP when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - joinedGroups := e.mu.addressableEndpointState.JoinedGroups() - for _, group := range joinedGroups { - e.igmp.joinGroup(group) - } + e.igmp.initializeAll() // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the // all-systems multicast group. - _, err = e.joinGroupLocked(header.IPv4AllSystems) - return err + if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv4 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err)) + } + + return nil } // Enabled implements stack.NetworkEndpoint. @@ -173,16 +175,13 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - if _, err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } // Leave groups from the perspective of IGMP so that routers know that // we are no longer interested in the group. - joinedGroups := e.mu.addressableEndpointState.JoinedGroups() - for _, group := range joinedGroups { - e.igmp.leaveGroup(group) - } + e.igmp.softLeaveAll() // The address may have already been removed. if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { @@ -849,69 +848,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.joinGroupLocked(addr) } -// joinGroupLocked is like JoinGroup, but with locking requirements. +// joinGroupLocked is like JoinGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV4MulticastAddress(addr) { - return false, tcpip.ErrBadAddress - } - // TODO(gvisor.dev/issue/4916): Keep track of join count and IGMP state in a - // single type. - joined, err := e.mu.addressableEndpointState.JoinGroup(addr) - if err != nil || !joined { - return joined, err - } - - // Only join the group from the perspective of IGMP when the endpoint is - // enabled. - // - // If we are not enabled right now, we will join the group from the - // perspective of IGMP when the endpoint is enabled. - if !e.Enabled() { - return true, nil - } - - // joinGroup only returns an error if we try to join a group twice, but we - // checked above to make sure that the group was newly joined. - if err := e.igmp.joinGroup(addr); err != nil { - panic(fmt.Sprintf("e.igmp.joinGroup(%s): %s", addr, err)) + return tcpip.ErrBadAddress } - return true, nil + e.igmp.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.leaveGroupLocked(addr) } -// leaveGroupLocked is like LeaveGroup, but with locking requirements. +// leaveGroupLocked is like LeaveGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { - left, err := e.mu.addressableEndpointState.LeaveGroup(addr) - if err != nil || !left { - return left, err - } - - e.igmp.leaveGroup(addr) - return true, nil +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.igmp.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.igmp.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4d49afcbb..084c38455 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -232,10 +232,7 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of MLD when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - joinedGroups := e.mu.addressableEndpointState.JoinedGroups() - for _, group := range joinedGroups { - e.mld.joinGroup(group) - } + e.mld.initializeAll() // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives @@ -254,8 +251,10 @@ func (e *endpoint) Enable() *tcpip.Error { // (NDP NS) messages may be sent to the All-Nodes multicast group if the // source address of the NDP NS is the unspecified address, as per RFC 4861 // section 7.2.4. - if _, err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil { - return err + if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err)) } // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent @@ -344,16 +343,13 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if _, err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } // Leave groups from the perspective of MLD so that routers know that // we are no longer interested in the group. - joinedGroups := e.mu.addressableEndpointState.JoinedGroups() - for _, group := range joinedGroups { - e.mld.leaveGroup(group) - } + e.mld.softLeaveAll() } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -1182,8 +1178,10 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.joinGroupLocked(snmc); err != nil { - return nil, err + if err := e.joinGroupLocked(snmc); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) } addressEndpoint.SetKind(stack.PermanentTentative) @@ -1239,7 +1237,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + // The endpoint may have already left the multicast group. + if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { return err } @@ -1404,70 +1403,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.joinGroupLocked(addr) } -// joinGroupLocked is like JoinGroup, but with locking requirements. +// joinGroupLocked is like JoinGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV6MulticastAddress(addr) { - return false, tcpip.ErrBadAddress - } - - // TODO(gvisor.dev/issue/4916): Keep track of join count and MLD state in a - // single type. - joined, err := e.mu.addressableEndpointState.JoinGroup(addr) - if err != nil || !joined { - return joined, err - } - - // Only join the group from the perspective of IGMP when the endpoint is - // enabled. - // - // If we are not enabled right now, we will join the group from the - // perspective of MLD when the endpoint is enabled. - if !e.Enabled() { - return true, nil - } - - // joinGroup only returns an error if we try to join a group twice, but we - // checked above to make sure that the group was newly joined. - if err := e.mld.joinGroup(addr); err != nil { - panic(fmt.Sprintf("e.mld.joinGroup(%s): %s", addr, err)) + return tcpip.ErrBadAddress } - return true, nil + e.mld.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.leaveGroupLocked(addr) } -// leaveGroupLocked is like LeaveGroup, but with locking requirements. +// leaveGroupLocked is like LeaveGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { - left, err := e.mu.addressableEndpointState.LeaveGroup(addr) - if err != nil || !left { - return left, err - } - - e.mld.leaveGroup(addr) - return true, nil +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.mld.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.mld.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index b16a1afb0..560c5e01e 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -50,8 +50,7 @@ var _ ip.MulticastGroupProtocol = (*mldState)(nil) // mldState.init MUST be called to initialize the MLD state. type mldState struct { // The IPv6 endpoint this mldState is for. - ep *endpoint - opts MLDOptions + ep *endpoint genericMulticastProtocol ip.GenericMulticastProtocolState } @@ -70,23 +69,21 @@ func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // a new mldState. func (mld *mldState) init(ep *endpoint, opts MLDOptions) { mld.ep = ep - mld.opts = opts - mld.genericMulticastProtocol.Init(ep.protocol.stack.Rand(), ep.protocol.stack.Clock(), mld, UnsolicitedReportIntervalMax) + mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ + Enabled: opts.Enabled, + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: mld, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv6AllNodesMulticastAddress, + }) } func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { - if !mld.opts.Enabled { - return - } - mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) } func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { - if !mld.opts.Enabled { - return - } - mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress()) } @@ -94,45 +91,37 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { // messages. // // If the group is already joined, returns tcpip.ErrDuplicateAddress. -func (mld *mldState) joinGroup(groupAddress tcpip.Address) *tcpip.Error { - if !mld.opts.Enabled { - return nil - } - - // As per RFC 2710 section 5 page 10, - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. - // - // This is equivalent to not performing MLD for the all-nodes multicast - // address. Simply not performing MLD when the group is added will prevent - // any work from being done on the all-nodes multicast group when leaving the - // group or when query or report messages are received for it since the MGP - // state will not know about it. - if groupAddress == header.IPv6AllNodesMulticastAddress { - return nil - } +func (mld *mldState) joinGroup(groupAddress tcpip.Address) { + mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */) +} - // JoinGroup returns false if we have already joined the group. - if !mld.genericMulticastProtocol.JoinGroup(groupAddress) { - return tcpip.ErrDuplicateAddress - } - return nil +// isInGroup returns true if the specified group has been joined locally. +func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { + return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Done message, if // required. -// -// If the group is not joined, this function will do nothing. -func (mld *mldState) leaveGroup(groupAddress tcpip.Address) { - if !mld.opts.Enabled { - return +func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + // LeaveGroup returns false only if the group was not joined. + if mld.genericMulticastProtocol.LeaveGroup(groupAddress) { + return nil } - mld.genericMulticastProtocol.LeaveGroup(groupAddress) + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of MLD, but remains +// joined locally. +func (mld *mldState) softLeaveAll() { + mld.genericMulticastProtocol.MakeAllNonMember() +} + +// initializeAll attemps to initialize the MLD state for each group that has +// been joined locally. +func (mld *mldState) initializeAll() { + mld.genericMulticastProtocol.InitializeGroups() } func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error { diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 6e855d815..b438e3acc 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) var _ AddressableEndpoint = (*AddressableEndpointState)(nil) // AddressableEndpointState is an implementation of an AddressableEndpoint. @@ -37,10 +36,6 @@ type AddressableEndpointState struct { endpoints map[tcpip.Address]*addressState primary []*addressState - - // groups holds the mapping between group addresses and the number of times - // they have been joined. - groups map[tcpip.Address]uint32 } } @@ -53,7 +48,6 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { a.mu.Lock() defer a.mu.Unlock() a.mu.endpoints = make(map[tcpip.Address]*addressState) - a.mu.groups = make(map[tcpip.Address]uint32) } // ReadOnlyAddressableEndpointState provides read-only access to an @@ -335,11 +329,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { a.mu.Lock() defer a.mu.Unlock() - - if _, ok := a.mu.groups[addr]; ok { - panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) - } - return a.removePermanentAddressLocked(addr) } @@ -588,61 +577,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi return addrs } -// JoinGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - a.mu.groups[group] = joins + 1 - return !ok, nil -} - -// LeaveGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - return false, tcpip.ErrBadLocalAddress - } - - if joins == 1 { - delete(a.mu.groups, group) - return true, nil - } - - a.mu.groups[group] = joins - 1 - return false, nil -} - -// IsInGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { - a.mu.RLock() - defer a.mu.RUnlock() - _, ok := a.mu.groups[group] - return ok -} - -// JoinedGroups returns a list of groups the endpoint is a member of. -func (a *AddressableEndpointState) JoinedGroups() []tcpip.Address { - a.mu.RLock() - defer a.mu.RUnlock() - groups := make([]tcpip.Address, 0, len(a.mu.groups)) - for g := range a.mu.groups { - groups = append(groups, g) - } - return groups -} - // Cleanup forcefully leaves all groups and removes all permanent addresses. func (a *AddressableEndpointState) Cleanup() { a.mu.Lock() defer a.mu.Unlock() - a.mu.groups = make(map[tcpip.Address]uint32) - for _, ep := range a.mu.endpoints { // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is // not a permanent address. diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 0c8040c67..140f146f6 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -15,40 +15,12 @@ package stack_test import ( - "sort" "testing" - "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) -func TestJoinedGroups(t *testing.T) { - const addr1 = tcpip.Address("\x01") - const addr2 = tcpip.Address("\x02") - - var ep fakeNetworkEndpoint - var s stack.AddressableEndpointState - s.Init(&ep) - - if joined, err := s.JoinGroup(addr1); err != nil { - t.Fatalf("JoinGroup(%s): %s", addr1, err) - } else if !joined { - t.Errorf("got JoinGroup(%s) = false, want = true", addr1) - } - if joined, err := s.JoinGroup(addr2); err != nil { - t.Fatalf("JoinGroup(%s): %s", addr2, err) - } else if !joined { - t.Errorf("got JoinGroup(%s) = false, want = true", addr2) - } - - joinedGroups := s.JoinedGroups() - sort.Slice(joinedGroups, func(i, j int) bool { return joinedGroups[i][0] < joinedGroups[j][0] }) - if diff := cmp.Diff([]tcpip.Address{addr1, addr2}, joinedGroups); diff != "" { - t.Errorf("joined groups mismatch (-want +got):\n%s", diff) - } -} - // TestAddressableEndpointStateCleanup tests that cleaning up an addressable // endpoint state removes permanent addresses and leaves groups. func TestAddressableEndpointStateCleanup(t *testing.T) { @@ -81,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { ep.DecRef() } - group := tcpip.Address("\x02") - if added, err := s.JoinGroup(group); err != nil { - t.Fatalf("s.JoinGroup(%s): %s", group, err) - } else if !added { - t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) - } - if !s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) - } - s.Cleanup() - { - ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) - if ep != nil { - ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) - } - } - if s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 5f3757de0..1805a8e0a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -563,8 +563,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address return tcpip.ErrNotSupported } - _, err := gep.JoinGroup(addr) - return err + return gep.JoinGroup(addr) } // leaveGroup decrements the count for the given multicast address, and when it @@ -580,11 +579,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres return tcpip.ErrNotSupported } - if _, err := gep.LeaveGroup(addr); err != nil { - return err - } - - return nil + return gep.LeaveGroup(addr) } // isInGroup returns true if n has joined the multicast group addr. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 43ca03ada..236e4e4c7 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -291,14 +291,10 @@ type NetworkHeaderParams struct { // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { // JoinGroup joins the specified group. - // - // Returns true if the group was newly joined. - JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + JoinGroup(group tcpip.Address) *tcpip.Error // LeaveGroup attempts to leave the specified group. - // - // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. - LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + LeaveGroup(group tcpip.Address) *tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool |