diff options
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol.go | 145 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 88 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 27 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 20 |
7 files changed, 199 insertions, 157 deletions
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index e308550c4..bc7f7f637 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -138,76 +138,89 @@ type MulticastGroupProtocol interface { // IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state // machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. // +// Callers must synchronize accesses to the generic multicast protocol state; +// GenericMulticastProtocolState obtains no locks in any of its methods. The +// only exception to this is GenericMulticastProtocolState's timer/job callbacks +// which will obtain the lock provided to the GenericMulticastProtocolState when +// it is initialized. +// // GenericMulticastProtocolState.Init MUST be called before calling any of // the methods on GenericMulticastProtocolState. type GenericMulticastProtocolState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + opts GenericMulticastProtocolOptions - mu struct { - sync.RWMutex + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState - // memberships holds group addresses and their associated state. - memberships map[tcpip.Address]multicastGroupState - } + // protocolMU is the mutex used to protect the protocol. + protocolMU *sync.RWMutex } // Init initializes the Generic Multicast Protocol state. -func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) { - g.mu.Lock() - defer g.mu.Unlock() +// +// Must only be called once for the lifetime of g. +// +// The GenericMulticastProtocolState will only grab the lock when timers/jobs +// fire. +func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { + if g.memberships != nil { + panic("attempted to initialize generic membership protocol state twice") + } + g.opts = opts - g.mu.memberships = make(map[tcpip.Address]multicastGroupState) + g.memberships = make(map[tcpip.Address]multicastGroupState) + g.protocolMU = protocolMU } -// MakeAllNonMember transitions all groups to the non-member state. +// MakeAllNonMemberLocked transitions all groups to the non-member state. // // The groups will still be considered joined locally. -func (g *GenericMulticastProtocolState) MakeAllNonMember() { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.transitionToNonMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// InitializeGroups initializes each group, as if they were newly joined but -// without affecting the groups' join count. +// InitializeGroupsLocked initializes each group, as if they were newly joined +// but without affecting the groups' join count. // // Must only be called after calling MakeAllNonMember as a group should not be // initialized while it is not in the non-member state. -func (g *GenericMulticastProtocolState) InitializeGroups() { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.initializeNewMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// JoinGroup handles joining a new group. +// JoinGroupLocked handles joining a new group. // // If dontInitialize is true, the group will be not be initialized and will be // left in the non-member state - no packets will be sent for it until it is // initialized via InitializeGroups. -func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) { - g.mu.Lock() - defer g.mu.Unlock() - - if info, ok := g.mu.memberships[groupAddress]; ok { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address, dontInitialize bool) { + if info, ok := g.memberships[groupAddress]; ok { // The group has already been joined. info.joins++ - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return } @@ -217,15 +230,15 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do // The state will be updated below, if required. state: nonMember, lastToSendReport: false, - delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() { - info, ok := g.mu.memberships[groupAddress] + delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + info, ok := g.memberships[groupAddress] if !ok { panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) } info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil info.state = idleMember - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info }), } @@ -233,25 +246,24 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do g.initializeNewMemberLocked(groupAddress, &info) } - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } -// IsLocallyJoined returns true if the group is locally joined. -func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool { - g.mu.RLock() - defer g.mu.RUnlock() - _, ok := g.mu.memberships[groupAddress] +// IsLocallyJoinedRLocked returns true if the group is locally joined. +// +// Precondition: g.protocolMU must be read locked. +func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { + _, ok := g.memberships[groupAddress] return ok } -// LeaveGroup handles leaving the group. +// LeaveGroupLocked handles leaving the group. // // Returns false if the group is not currently joined. -func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool { - g.mu.Lock() - defer g.mu.Unlock() - - info, ok := g.mu.memberships[groupAddress] +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { + info, ok := g.memberships[groupAddress] if !ok { return false } @@ -262,30 +274,30 @@ func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) b info.joins-- if info.joins != 0 { // If we still have outstanding joins, then do nothing further. - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return true } g.transitionToNonMemberLocked(groupAddress, &info) - delete(g.mu.memberships, groupAddress) + delete(g.memberships, groupAddress) return true } -// HandleQuery handles a query message with the specified maximum response time. +// HandleQueryLocked handles a query message with the specified maximum response +// time. // // If the group address is unspecified, then reports will be scheduled for all // joined groups. // // Report(s) will be scheduled to be sent after a random duration between 0 and // the maximum response time. -func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 2.4 (for IGMPv2), // // In a Membership Query message, the group address field is set to zero @@ -299,28 +311,27 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, // when sending a Multicast-Address-Specific Query. if groupAddress.Unspecified() { // This is a general query as the group address is unspecified. - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } - } else if info, ok := g.mu.memberships[groupAddress]; ok { + } else if info, ok := g.memberships[groupAddress]; ok { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// HandleReport handles a report message. +// HandleReportLocked handles a report message. // // If the report is for a joined group, any active delayed report will be // cancelled and the host state for the group transitions to idle. -func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), // // If the host receives another host's Report (version 1 or 2) while it has @@ -333,17 +344,17 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) // multicast address while it has a timer running for that same address // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. - if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember { + if info, ok := g.memberships[groupAddress]; ok && info.state == delayingMember { info.delayedReportJob.Cancel() info.lastToSendReport = false info.state = idleMember - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } // initializeNewMemberLocked initializes a new group membership. // -// Precondition: g.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state != nonMember { panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state)) @@ -465,7 +476,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres // transitionToNonMemberLocked transitions the given multicast group the the // non-member/listener state. // -// Precondition: e.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state == nonMember { return @@ -479,7 +490,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress // setDelayTimerForAddressRLocked sets timer to send a delay report. // -// Precondition: g.mu MUST be read locked. +// Precondition: g.protocolMU MUST be read locked. func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { if info.state == nonMember { return diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 0134fadc0..2ee4c6445 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -16,7 +16,6 @@ package ipv4 import ( "fmt" - "sync" "sync/atomic" "time" @@ -68,8 +67,9 @@ var _ ip.MulticastGroupProtocol = (*igmpState)(nil) // igmpState.init() MUST be called after creating an IGMP state. type igmpState struct { // The IPv4 endpoint this igmpState is for. - ep *endpoint - opts IGMPOptions + ep *endpoint + + genericMulticastProtocol ip.GenericMulticastProtocolState // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 @@ -84,16 +84,10 @@ type igmpState struct { // when false. igmpV1Present uint32 - mu struct { - sync.RWMutex - - genericMulticastProtocol ip.GenericMulticastProtocolState - - // igmpV1Job is scheduled when this interface receives an IGMPv1 style - // message, upon expiration the igmpV1Present flag is cleared. - // igmpV1Job may not be nil once igmpState is initialized. - igmpV1Job *tcpip.Job - } + // igmpV1Job is scheduled when this interface receives an IGMPv1 style + // message, upon expiration the igmpV1Present flag is cleared. + // igmpV1Job may not be nil once igmpState is initialized. + igmpV1Job *tcpip.Job } // SendReport implements ip.MulticastGroupProtocol. @@ -119,13 +113,12 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // init sets up an igmpState struct, and is required to be called before using // a new igmpState. -func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { - igmp.mu.Lock() - defer igmp.mu.Unlock() +// +// Must only be called once for the lifetime of igmp. +func (igmp *igmpState) init(ep *endpoint) { igmp.ep = ep - igmp.opts = opts - igmp.mu.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ - Enabled: opts.Enabled, + igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ + Enabled: ep.protocol.options.IGMP.Enabled, Rand: ep.protocol.stack.Rand(), Clock: ep.protocol.stack.Clock(), Protocol: igmp, @@ -133,11 +126,14 @@ func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { AllNodesAddress: header.IPv4AllSystems, }) igmp.igmpV1Present = igmpV1PresentDefault - igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { + igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { igmp.setV1Present(false) }) } +// handleIGMP handles an IGMP packet. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { stats := igmp.ep.protocol.stack.Stats() received := stats.IGMP.PacketsReceived @@ -207,27 +203,28 @@ func (igmp *igmpState) setV1Present(v bool) { } } +// handleMembershipQuery handles a membership query. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero // then change the state to note that an IGMPv1 router is present and // schedule the query received Job. - if maxRespTime == 0 && igmp.opts.Enabled { - igmp.mu.igmpV1Job.Cancel() - igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout) + if maxRespTime == 0 && igmp.ep.protocol.options.IGMP.Enabled { + igmp.igmpV1Job.Cancel() + igmp.igmpV1Job.Schedule(v1RouterPresentTimeout) igmp.setV1Present(true) maxRespTime = v1MaxRespTime } - igmp.mu.genericMulticastProtocol.HandleQuery(groupAddress, maxRespTime) + igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime) } +// handleMembershipReport handles a membership report. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.HandleReport(groupAddress) + igmp.genericMulticastProtocol.HandleReportLocked(groupAddress) } // writePacket assembles and sends an IGMP packet with the provided fields, @@ -278,28 +275,27 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // // If the group already exists in the membership map, returns // tcpip.ErrDuplicateAddress. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress, !igmp.ep.Enabled() /* dontInitialize */) + igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress, !igmp.ep.Enabled() /* dontInitialize */) } // isInGroup returns true if the specified group has been joined locally. +// +// Precondition: igmp.ep.mu must be read locked. func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { - igmp.mu.Lock() - defer igmp.mu.Unlock() - return igmp.mu.genericMulticastProtocol.IsLocallyJoined(groupAddress) + return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Leave Group message // if required. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { - igmp.mu.Lock() - defer igmp.mu.Unlock() - // LeaveGroup returns false only if the group was not joined. - if igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) { + if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } @@ -308,16 +304,16 @@ func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // softLeaveAll leaves all groups from the perspective of IGMP, but remains // joined locally. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) softLeaveAll() { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.MakeAllNonMember() + igmp.genericMulticastProtocol.MakeAllNonMemberLocked() } // initializeAll attemps to initialize the IGMP state for each group that has // been joined locally. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) initializeAll() { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.InitializeGroups() + igmp.genericMulticastProtocol.InitializeGroupsLocked() } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 3076185cd..c63ecca4a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -72,7 +72,6 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol - igmp igmpState // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -84,6 +83,7 @@ type endpoint struct { sync.RWMutex addressableEndpointState stack.AddressableEndpointState + igmp igmpState } } @@ -94,8 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.igmp.init(e, p.options.IGMP) + e.mu.igmp.init(e) + e.mu.Unlock() return e } @@ -127,7 +129,7 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of IGMP when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - e.igmp.initializeAll() + e.mu.igmp.initializeAll() // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the @@ -181,7 +183,7 @@ func (e *endpoint) disableLocked() { // Leave groups from the perspective of IGMP so that routers know that // we are no longer interested in the group. - e.igmp.softLeaveAll() + e.mu.igmp.softLeaveAll() // The address may have already been removed. if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { @@ -718,7 +720,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } if p == header.IGMPProtocolNumber { - e.igmp.handleIGMP(pkt) + e.mu.Lock() + e.mu.igmp.handleIGMP(pkt) + e.mu.Unlock() return } if opts := h.Options(); len(opts) != 0 { @@ -843,7 +847,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadAddress } - e.igmp.joinGroup(addr) + e.mu.igmp.joinGroup(addr) return nil } @@ -858,14 +862,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // // Precondition: e.mu must be locked. func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { - return e.igmp.leaveGroup(addr) + return e.mu.igmp.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.igmp.isInGroup(addr) + return e.mu.igmp.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 510276b8e..6ee162713 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -645,26 +645,34 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: - var handler func(header.MLD) switch icmpType { case header.ICMPv6MulticastListenerQuery: received.MulticastListenerQuery.Increment() - handler = e.mld.handleMulticastListenerQuery case header.ICMPv6MulticastListenerReport: received.MulticastListenerReport.Increment() - handler = e.mld.handleMulticastListenerReport case header.ICMPv6MulticastListenerDone: received.MulticastListenerDone.Increment() default: panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { received.Invalid.Increment() return } - if handler != nil { - handler(header.MLD(payload.ToView())) + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + e.mu.Lock() + e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerReport: + e.mu.Lock() + e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerDone: + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } default: diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 8bf84601f..7288e309c 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -85,9 +85,8 @@ type endpoint struct { addressableEndpointState stack.AddressableEndpointState ndp ndpState + mld mldState } - - mld mldState } // NICNameFromID is a function that returns a stable name for the specified NIC, @@ -232,7 +231,7 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of MLD when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - e.mld.initializeAll() + e.mu.mld.initializeAll() // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives @@ -349,7 +348,7 @@ func (e *endpoint) disableLocked() { // Leave groups from the perspective of MLD so that routers know that // we are no longer interested in the group. - e.mld.softLeaveAll() + e.mu.mld.softLeaveAll() } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -1417,7 +1416,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadAddress } - e.mld.joinGroup(addr) + e.mu.mld.joinGroup(addr) return nil } @@ -1432,14 +1431,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // // Precondition: e.mu must be locked. func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { - return e.mld.leaveGroup(addr) + return e.mu.mld.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mld.isInGroup(addr) + return e.mu.mld.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) @@ -1504,17 +1503,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.mu.ndp = ndpState{ - ep: e, - configs: p.options.NDPConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), - } - e.mu.ndp.initializeTempAddrState() - e.mld.init(e, p.options.MLD) + e.mu.ndp.init(e) + e.mu.mld.init(e) + e.mu.Unlock() p.mu.Lock() defer p.mu.Unlock() diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 4c06b3f0c..6face17c6 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -67,10 +67,12 @@ func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // init sets up an mldState struct, and is required to be called before using // a new mldState. -func (mld *mldState) init(ep *endpoint, opts MLDOptions) { +// +// Must only be called once for the lifetime of mld. +func (mld *mldState) init(ep *endpoint) { mld.ep = ep - mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ - Enabled: opts.Enabled, + mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ + Enabled: ep.protocol.options.MLD.Enabled, Rand: ep.protocol.stack.Rand(), Clock: ep.protocol.stack.Clock(), Protocol: mld, @@ -79,33 +81,45 @@ func (mld *mldState) init(ep *endpoint, opts MLDOptions) { }) } +// handleMulticastListenerQuery handles a query message. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { - mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) + mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) } +// handleMulticastListenerReport handles a report message. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { - mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress()) + mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress()) } // joinGroup handles joining a new group and sending and scheduling the required // messages. // // If the group is already joined, returns tcpip.ErrDuplicateAddress. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) joinGroup(groupAddress tcpip.Address) { - mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */) + mld.genericMulticastProtocol.JoinGroupLocked(groupAddress, !mld.ep.Enabled() /* dontInitialize */) } // isInGroup returns true if the specified group has been joined locally. +// +// Precondition: mld.ep.mu must be read locked. func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { - return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress) + return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Done message, if // required. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // LeaveGroup returns false only if the group was not joined. - if mld.genericMulticastProtocol.LeaveGroup(groupAddress) { + if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } @@ -114,14 +128,18 @@ func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // softLeaveAll leaves all groups from the perspective of MLD, but remains // joined locally. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) softLeaveAll() { - mld.genericMulticastProtocol.MakeAllNonMember() + mld.genericMulticastProtocol.MakeAllNonMemberLocked() } // initializeAll attemps to initialize the MLD state for each group that has // been joined locally. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) initializeAll() { - mld.genericMulticastProtocol.InitializeGroups() + mld.genericMulticastProtocol.InitializeGroupsLocked() } func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error { diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 8cb7d4dab..2f5e2e82c 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -20,6 +20,7 @@ import ( "math/rand" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + // The IPv6 endpoint this ndpState is for. ep *endpoint @@ -1884,11 +1888,19 @@ func (ndp *ndpState) stopSolicitingRouters() { ndp.rtrSolicitJob = nil } -// initializeTempAddrState initializes state related to temporary SLAAC -// addresses. -func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) +func (ndp *ndpState) init(ep *endpoint) { + if ndp.dad != nil { + panic("attempted to initialize NDP state twice") + } + ndp.ep = ep + ndp.configs = ep.protocol.options.NDPConfigs + ndp.dad = make(map[tcpip.Address]dadState) + ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) + ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) + ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) + + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) } |