diff options
Diffstat (limited to 'pkg/tcpip/network/ipv6')
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld.go | 75 |
2 files changed, 56 insertions, 95 deletions
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4d49afcbb..084c38455 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -232,10 +232,7 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of MLD when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - joinedGroups := e.mu.addressableEndpointState.JoinedGroups() - for _, group := range joinedGroups { - e.mld.joinGroup(group) - } + e.mld.initializeAll() // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives @@ -254,8 +251,10 @@ func (e *endpoint) Enable() *tcpip.Error { // (NDP NS) messages may be sent to the All-Nodes multicast group if the // source address of the NDP NS is the unspecified address, as per RFC 4861 // section 7.2.4. - if _, err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil { - return err + if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err)) } // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent @@ -344,16 +343,13 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if _, err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } // Leave groups from the perspective of MLD so that routers know that // we are no longer interested in the group. - joinedGroups := e.mu.addressableEndpointState.JoinedGroups() - for _, group := range joinedGroups { - e.mld.leaveGroup(group) - } + e.mld.softLeaveAll() } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -1182,8 +1178,10 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.joinGroupLocked(snmc); err != nil { - return nil, err + if err := e.joinGroupLocked(snmc); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) } addressEndpoint.SetKind(stack.PermanentTentative) @@ -1239,7 +1237,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + // The endpoint may have already left the multicast group. + if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { return err } @@ -1404,70 +1403,43 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.joinGroupLocked(addr) } -// joinGroupLocked is like JoinGroup, but with locking requirements. +// joinGroupLocked is like JoinGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV6MulticastAddress(addr) { - return false, tcpip.ErrBadAddress - } - - // TODO(gvisor.dev/issue/4916): Keep track of join count and MLD state in a - // single type. - joined, err := e.mu.addressableEndpointState.JoinGroup(addr) - if err != nil || !joined { - return joined, err - } - - // Only join the group from the perspective of IGMP when the endpoint is - // enabled. - // - // If we are not enabled right now, we will join the group from the - // perspective of MLD when the endpoint is enabled. - if !e.Enabled() { - return true, nil - } - - // joinGroup only returns an error if we try to join a group twice, but we - // checked above to make sure that the group was newly joined. - if err := e.mld.joinGroup(addr); err != nil { - panic(fmt.Sprintf("e.mld.joinGroup(%s): %s", addr, err)) + return tcpip.ErrBadAddress } - return true, nil + e.mld.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.leaveGroupLocked(addr) } -// leaveGroupLocked is like LeaveGroup, but with locking requirements. +// leaveGroupLocked is like LeaveGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { - left, err := e.mu.addressableEndpointState.LeaveGroup(addr) - if err != nil || !left { - return left, err - } - - e.mld.leaveGroup(addr) - return true, nil +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.mld.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.mld.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index b16a1afb0..560c5e01e 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -50,8 +50,7 @@ var _ ip.MulticastGroupProtocol = (*mldState)(nil) // mldState.init MUST be called to initialize the MLD state. type mldState struct { // The IPv6 endpoint this mldState is for. - ep *endpoint - opts MLDOptions + ep *endpoint genericMulticastProtocol ip.GenericMulticastProtocolState } @@ -70,23 +69,21 @@ func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // a new mldState. func (mld *mldState) init(ep *endpoint, opts MLDOptions) { mld.ep = ep - mld.opts = opts - mld.genericMulticastProtocol.Init(ep.protocol.stack.Rand(), ep.protocol.stack.Clock(), mld, UnsolicitedReportIntervalMax) + mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ + Enabled: opts.Enabled, + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: mld, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv6AllNodesMulticastAddress, + }) } func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { - if !mld.opts.Enabled { - return - } - mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) } func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { - if !mld.opts.Enabled { - return - } - mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress()) } @@ -94,45 +91,37 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { // messages. // // If the group is already joined, returns tcpip.ErrDuplicateAddress. -func (mld *mldState) joinGroup(groupAddress tcpip.Address) *tcpip.Error { - if !mld.opts.Enabled { - return nil - } - - // As per RFC 2710 section 5 page 10, - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. - // - // This is equivalent to not performing MLD for the all-nodes multicast - // address. Simply not performing MLD when the group is added will prevent - // any work from being done on the all-nodes multicast group when leaving the - // group or when query or report messages are received for it since the MGP - // state will not know about it. - if groupAddress == header.IPv6AllNodesMulticastAddress { - return nil - } +func (mld *mldState) joinGroup(groupAddress tcpip.Address) { + mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */) +} - // JoinGroup returns false if we have already joined the group. - if !mld.genericMulticastProtocol.JoinGroup(groupAddress) { - return tcpip.ErrDuplicateAddress - } - return nil +// isInGroup returns true if the specified group has been joined locally. +func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { + return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Done message, if // required. -// -// If the group is not joined, this function will do nothing. -func (mld *mldState) leaveGroup(groupAddress tcpip.Address) { - if !mld.opts.Enabled { - return +func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + // LeaveGroup returns false only if the group was not joined. + if mld.genericMulticastProtocol.LeaveGroup(groupAddress) { + return nil } - mld.genericMulticastProtocol.LeaveGroup(groupAddress) + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of MLD, but remains +// joined locally. +func (mld *mldState) softLeaveAll() { + mld.genericMulticastProtocol.MakeAllNonMember() +} + +// initializeAll attemps to initialize the MLD state for each group that has +// been joined locally. +func (mld *mldState) initializeAll() { + mld.genericMulticastProtocol.InitializeGroups() } func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error { |